Browse Source

模型训练支持断点续训

liyan 7 months ago
parent
commit
ea2ca2f6de
2 changed files with 34 additions and 1 deletion
  1. 13 1
      train_alexnet.py
  2. 21 0
      train_vgg16.py

+ 13 - 1
train_alexnet.py

@@ -61,6 +61,16 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     return train_dataset, val_dataset
 
 
+def find_latest_checkpoint(directory):
+    # Collect every .h5 checkpoint file saved in the given directory
+    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
+    if not checkpoint_files:
+        return None
+    # Sort by the epoch number embedded after the last '_' in each filename, so the newest epoch sorts last
+    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
+    return os.path.join(directory, checkpoint_files[-1])
+
+
 def train_model(args, train_data, val_data):
     # Create model
     model = create_model()
@@ -81,12 +91,14 @@ def train_model(args, train_data, val_data):
     model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
 
     # Check if a checkpoint exists and determine the initial_epoch
-    latest_checkpoint = tf.train.latest_checkpoint(args.output_dir)
+    latest_checkpoint = find_latest_checkpoint(args.output_dir)
     if latest_checkpoint:
+        model.load_weights(latest_checkpoint)  # Load the weights from the checkpoint
         initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # Get the last epoch from filename
         print(f"Resuming training from epoch {initial_epoch}")
     else:
         initial_epoch = 0
+        print("No checkpoint found. Starting training from scratch.")
 
     # Define CSVLogger to log training history to a CSV file
     csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

+ 21 - 0
train_vgg16.py

@@ -39,6 +39,16 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     return train_generator, val_generator
 
 
+def find_latest_checkpoint(directory):
+    # Collect every .h5 checkpoint file saved in the given directory
+    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
+    if not checkpoint_files:
+        return None
+    # Sort by the epoch number embedded after the last '_' in each filename, so the newest epoch sorts last
+    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
+    return os.path.join(directory, checkpoint_files[-1])
+
+
 def train_model(args, train_generator, val_generator):
     # Create model
     model = create_model()
@@ -60,6 +70,16 @@ def train_model(args, train_generator, val_generator):
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])
 
+    # Check if a checkpoint exists and determine the initial_epoch
+    latest_checkpoint = find_latest_checkpoint(args.output_dir)
+    if latest_checkpoint:
+        model.load_weights(latest_checkpoint)  # Load the weights from the checkpoint
+        initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # Get the last epoch from filename
+        print(f"Resuming training from epoch {initial_epoch}")
+    else:
+        initial_epoch = 0
+        print("No checkpoint found. Starting training from scratch.")
+
     # Define CSVLogger to log training history to a CSV file
     csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
 
@@ -78,6 +98,7 @@ def train_model(args, train_generator, val_generator):
         epochs=args.epochs,
         validation_data=val_generator,
         validation_steps=val_generator.samples // val_generator.batch_size,
+        initial_epoch=initial_epoch,
         callbacks=[csv_logger, checkpoint_callback]
     )