|
@@ -39,6 +39,16 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
return train_generator, val_generator
|
|
return train_generator, val_generator
|
|
|
|
|
|
|
|
|
|
|
|
+def find_latest_checkpoint(directory):
|
|
|
|
+ # 获取指定目录下的所有 .h5 文件
|
|
|
|
+ checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
|
|
|
|
+ if not checkpoint_files:
|
|
|
|
+ return None
|
|
|
|
+ # 按照文件名中的数字进行排序,找到最新的 epoch 文件
|
|
|
|
+ checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
|
|
|
|
+ return os.path.join(directory, checkpoint_files[-1])
|
|
|
|
+
|
|
|
|
+
|
|
def train_model(args, train_generator, val_generator):
|
|
def train_model(args, train_generator, val_generator):
|
|
# Create model
|
|
# Create model
|
|
model = create_model()
|
|
model = create_model()
|
|
@@ -60,6 +70,16 @@ def train_model(args, train_generator, val_generator):
|
|
loss='categorical_crossentropy',
|
|
loss='categorical_crossentropy',
|
|
metrics=['accuracy'])
|
|
metrics=['accuracy'])
|
|
|
|
|
|
|
|
+ # Check if a checkpoint exists and determine the initial_epoch
|
|
|
|
+ latest_checkpoint = find_latest_checkpoint(args.output_dir)
|
|
|
|
+ if latest_checkpoint:
|
|
|
|
+ model.load_weights(latest_checkpoint) # Load the weights from the checkpoint
|
|
|
|
+ initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0]) # Get the last epoch from filename
|
|
|
|
+ print(f"Resuming training from epoch {initial_epoch}")
|
|
|
|
+ else:
|
|
|
|
+ initial_epoch = 0
|
|
|
|
+ print("No checkpoint found. Starting training from scratch.")
|
|
|
|
+
|
|
# Define CSVLogger to log training history to a CSV file
|
|
# Define CSVLogger to log training history to a CSV file
|
|
csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
|
|
csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
|
|
|
|
|
|
@@ -78,6 +98,7 @@ def train_model(args, train_generator, val_generator):
|
|
epochs=args.epochs,
|
|
epochs=args.epochs,
|
|
validation_data=val_generator,
|
|
validation_data=val_generator,
|
|
validation_steps=val_generator.samples // val_generator.batch_size,
|
|
validation_steps=val_generator.samples // val_generator.batch_size,
|
|
|
|
+ initial_epoch=initial_epoch,
|
|
callbacks=[csv_logger, checkpoint_callback]
|
|
callbacks=[csv_logger, checkpoint_callback]
|
|
)
|
|
)
|
|
|
|
|