@@ -9,13 +9,6 @@ from tensorflow.keras.preprocessing import image_dataset_from_directory
 
 
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
-    # Define data augmentation for the training set
-    # train_datagen = tf.keras.Sequential([
-    #     tf.keras.layers.RandomFlip('horizontal'),
-    #     tf.keras.layers.RandomRotation(0.2),
-    #     tf.keras.layers.RandomZoom(0.2),
-    #     tf.keras.layers.RandomContrast(0.2),
-    # ])
     def augment(image):
         # Random horizontal flip
         image = tf.image.random_flip_left_right(image)
@@ -28,29 +21,33 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Load training dataset
     train_dataset = image_dataset_from_directory(
         train_dir,
-        image_size=img_size,  # Resize images to (224, 224)
+        image_size=img_size,
         batch_size=batch_size,
-        label_mode='categorical',  # Return integer labels
+        label_mode='categorical',
         shuffle=True
     )
 
     # Load validation dataset
     val_dataset = image_dataset_from_directory(
         val_dir,
-        image_size=img_size,  # Resize images to (224, 224)
+        image_size=img_size,
         batch_size=batch_size,
-        label_mode='categorical',  # Return integer labels
+        label_mode='categorical',
         shuffle=False
     )
 
-    # Normalize the datasets (rescale pixel values to [0, 1])
+    # Define mean and std for standardization (ImageNet values)
+    mean = tf.constant([0.485, 0.456, 0.406])
+    std = tf.constant([0.229, 0.224, 0.225])
+
+    # Normalize and standardize the datasets
     train_dataset = train_dataset.map(
-        lambda x, y: (augment(x) / 255.0, y),
+        lambda x, y: ((augment(x) / 255.0 - mean) / std, y),
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
     val_dataset = val_dataset.map(
-        lambda x, y: (x / 255.0, y),
+        lambda x, y: ((x / 255.0 - mean) / std, y),
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
@@ -105,7 +102,7 @@ def train_model(args, train_data, val_data):
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(args.output_dir, 'alexnet_loss_{val_loss:.4f}_{epoch:03d}.h5'),
+        filepath=os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
         save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
         monitor='val_loss',  # Monitor the validation loss