import os

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from models.VGG16 import create_model


def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
    # Channel-wise mean and standard deviation (ImageNet statistics)
    mean = tf.constant([0.485, 0.456, 0.406])
    std = tf.constant([0.229, 0.224, 0.225])

    # Custom standardization function. ImageDataGenerator applies
    # preprocessing_function before rescale, so the scaling to [0, 1]
    # is done here rather than via the rescale argument.
    def normalize_image(x):
        return (x / 255.0 - mean) / std

    # Load the images with ImageDataGenerator and apply augmentation/preprocessing
    train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        preprocessing_function=normalize_image  # custom standardization
    )

    val_datagen = ImageDataGenerator(
        preprocessing_function=normalize_image  # custom standardization
    )

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=img_size,        # VGG16 input size
        batch_size=batch_size,
        class_mode='categorical'     # multi-class task
    )

    val_generator = val_datagen.flow_from_directory(
        val_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical'
    )

    return train_generator, val_generator


def find_latest_checkpoint(directory):
    # Collect all .h5 files in the given directory
    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
    if not checkpoint_files:
        return None

    # Sort by the epoch number embedded in the filename and return the latest one
    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
    return os.path.join(directory, checkpoint_files[-1])


def train_model(args, train_generator, val_generator):
    # Create model
    model = create_model()

    # Learning rate
    learning_rate = args.lr if args.lr else 1e-2

    # Select optimizer based on args.opt
    if args.opt == 'sgd':
        optimizer = tf.keras.optimizers.SGD(
            learning_rate=learning_rate,
            momentum=args.momentum if args.momentum else 0.0
        )
    elif args.opt == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    else:
        # Default to Adam for any unrecognized optimizer name
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Compile the model
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    # Check whether a checkpoint exists and determine the initial epoch
    latest_checkpoint = find_latest_checkpoint(args.output_dir)
    if latest_checkpoint:
        model.load_weights(latest_checkpoint)  # load the weights from the checkpoint
        initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # last epoch from the filename
        print(f"Resuming training from epoch {initial_epoch}")
    else:
        initial_epoch = 0
        print("No checkpoint found. Starting training from scratch.")

    # CSVLogger appends the per-epoch training history to a CSV file
    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

    # ModelCheckpoint saves the full model after every epoch as vgg16_{epoch}.h5
    checkpoint_callback = ModelCheckpoint(
        os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),
        save_weights_only=False,
        monitor='val_loss',   # monitor the validation loss
        save_freq='epoch',    # save after every epoch
        verbose=1
    )

    # Train the model
    history = model.fit(
        train_generator,
        steps_per_epoch=train_generator.samples // train_generator.batch_size,
        epochs=args.epochs,
        validation_data=val_generator,
        validation_steps=val_generator.samples // val_generator.batch_size,
        initial_epoch=initial_epoch,
        callbacks=[csv_logger, checkpoint_callback]
    )

    return history


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="Keras VGG16 Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
    parser.add_argument("--output-dir", default="checkpoints/vgg16", type=str, help="path to save outputs")
    parser.add_argument(
        "-b", "--batch-size", default=2, type=int, help="images per GPU, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Directories of the custom dataset
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    # Directory where the checkpoints will be saved
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_generator, val_generator = load_data(
        train_dir, val_dir,
        img_size=(args.input_size, args.input_size),
        batch_size=args.batch_size
    )

    # Start training
    train_model(args, train_generator, val_generator)