@@ -105,9 +105,10 @@ def train_model(args, train_data, val_data):
# Define ModelCheckpoint callback to save weights for each epoch
checkpoint_callback = ModelCheckpoint(
- os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'), # Save weights as alexnet_{epoch}.h5
+ filepath=os.path.join(args.output_dir, 'alexnet_loss_{val_loss:.4f}_{epoch:03d}.h5'),
save_weights_only=False,
save_freq='epoch', # Save after every epoch
+ monitor='val_loss', # Monitor the validation loss
verbose=1
)
@@ -85,8 +85,9 @@ def train_model(args, train_generator, val_generator):
- os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
+ os.path.join(args.output_dir, 'vgg16_loss_{val_loss:.4f}_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5