Преглед на файлове

修改训练保存权重参数,解决导出onnx失败的问题

liyan преди 7 месеца
родител
ревизия
2d157c9dff
променени са 1 файла, в които са добавени 1 реда и са изтрити 1 реда
  1. 1 1
      train_alexnet.py

+ 1 - 1
train_alexnet.py

@@ -86,7 +86,7 @@ def train_model(args, train_data, val_data):
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
         os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
-        save_weights_only=True,
+        save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
         verbose=1
     )