@@ -86,7 +86,7 @@ def train_model(args, train_data, val_data):
# Define ModelCheckpoint callback to save weights for each epoch
checkpoint_callback = ModelCheckpoint(
os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'), # Save weights as alexnet_{epoch}.h5
- save_weights_only=True,
+ save_weights_only=False,
save_freq='epoch', # Save after every epoch
verbose=1
)