|
@@ -85,8 +85,9 @@ def train_model(args, train_generator, val_generator):
|
|
|
|
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
- os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
|
|
|
+ os.path.join(args.output_dir, 'vgg16_loss_{val_loss:.4f}_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
|
|
|
save_weights_only=False,
|
|
|
+ monitor='val_loss', # Monitor the validation loss
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
verbose=1
|
|
|
)
|