ソースを参照

训练脚本权重文件名添加验证集损失

liyan 7 ヶ月 前
コミット
12e1533de1
2 ファイル変更4 行追加2 行削除
  1. 2 1
      train_alexnet.py
  2. 2 1
      train_vgg16.py

+ 2 - 1
train_alexnet.py

@@ -105,9 +105,10 @@ def train_model(args, train_data, val_data):
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
+        filepath=os.path.join(args.output_dir, 'alexnet_loss_{val_loss:.4f}_{epoch:03d}.h5'),
         save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
+        monitor='val_loss',  # Monitor the validation loss
         verbose=1
     )
 

+ 2 - 1
train_vgg16.py

@@ -85,8 +85,9 @@ def train_model(args, train_generator, val_generator):
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
+        os.path.join(args.output_dir, 'vgg16_loss_{val_loss:.4f}_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
         save_weights_only=False,
+        monitor='val_loss',  # Monitor the validation loss
         save_freq='epoch',  # Save after every epoch
         verbose=1
     )