Browse Source

修改训练保存权重参数,解决导出onnx失败的问题

liyan 7 months ago
parent
commit
2d157c9dff
1 changed files with 1 additions and 1 deletions
  1. 1 1
      train_alexnet.py

+ 1 - 1
train_alexnet.py

@@ -86,7 +86,7 @@ def train_model(args, train_data, val_data):
     # Define ModelCheckpoint callback to save weights for each epoch
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
     checkpoint_callback = ModelCheckpoint(
         os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
         os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
-        save_weights_only=True,
+        save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
         save_freq='epoch',  # Save after every epoch
         verbose=1
         verbose=1
     )
     )