|
@@ -38,12 +38,18 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
|
# Normalize the datasets (rescale pixel values to [0, 1])
|
|
|
train_dataset = train_dataset.map(
|
|
|
lambda x, y: (train_datagen(x) / 255.0, y),
|
|
|
+ num_parallel_calls=tf.data.AUTOTUNE
|
|
|
)
|
|
|
|
|
|
val_dataset = val_dataset.map(
|
|
|
lambda x, y: (x / 255.0, y),
|
|
|
+ num_parallel_calls=tf.data.AUTOTUNE
|
|
|
)
|
|
|
|
|
|
+ # Prefetch to improve performance
|
|
|
+ train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
|
|
|
+ val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
|
|
|
+
|
|
|
return train_dataset, val_dataset
|
|
|
|
|
|
|