|
@@ -86,7 +86,7 @@ def train_model(args, train_data, val_data):
|
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'), # Save weights as alexnet_{epoch}.h5
|
|
|
- save_weights_only=True,
|
|
|
+ save_weights_only=False,
|
|
|
save_freq='epoch', # Save after every epoch
|
|
|
verbose=1
|
|
|
)
|