# train_vgg16.py — trains a VGG16 image classifier with TensorFlow/Keras.
import os

import tensorflow as tf
from keras.callbacks import CSVLogger, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from models.VGG16 import create_model
  6. def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
  7. # 使用 ImageDataGenerator 加载图像并进行预处理
  8. train_datagen = ImageDataGenerator(
  9. rescale=1.0 / 255.0, # 归一化
  10. rotation_range=40,
  11. width_shift_range=0.2,
  12. height_shift_range=0.2,
  13. shear_range=0.2,
  14. zoom_range=0.2,
  15. horizontal_flip=True,
  16. fill_mode='nearest'
  17. )
  18. val_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
  19. train_generator = train_datagen.flow_from_directory(
  20. train_dir,
  21. target_size=img_size, # VGG16 输入大小
  22. batch_size=batch_size,
  23. class_mode='categorical' # 多分类任务
  24. )
  25. val_generator = val_datagen.flow_from_directory(
  26. val_dir,
  27. target_size=img_size,
  28. batch_size=batch_size,
  29. class_mode='categorical'
  30. )
  31. return train_generator, val_generator
  32. def train_model(args, train_generator, val_generator):
  33. # Create model
  34. model = create_model()
  35. # 调整学习率
  36. learning_rate = args.lr if args.lr else 1e-2
  37. # Select optimizer based on args.opt
  38. if args.opt == 'sgd':
  39. optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
  40. momentum=args.momentum if args.momentum else 0.0)
  41. elif args.opt == 'adam':
  42. optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  43. else:
  44. optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) # Default to Adam if unspecified
  45. # 编译模型
  46. model.compile(optimizer=optimizer,
  47. loss='categorical_crossentropy',
  48. metrics=['accuracy'])
  49. # Define CSVLogger to log training history to a CSV file
  50. csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
  51. # Define ModelCheckpoint callback to save weights for each epoch
  52. checkpoint_callback = ModelCheckpoint(
  53. os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
  54. save_weights_only=False,
  55. save_freq='epoch', # Save after every epoch
  56. verbose=1
  57. )
  58. # 训练模型
  59. history = model.fit(
  60. train_generator,
  61. steps_per_epoch=train_generator.samples // train_generator.batch_size,
  62. epochs=args.epochs,
  63. validation_data=val_generator,
  64. validation_steps=val_generator.samples // val_generator.batch_size,
  65. callbacks=[csv_logger, checkpoint_callback]
  66. )
  67. return history
  68. def get_args_parser(add_help=True):
  69. import argparse
  70. parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
  71. parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
  72. parser.add_argument("--output-dir", default="checkpoints/vgg16", type=str, help="path to save outputs")
  73. parser.add_argument(
  74. "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
  75. )
  76. parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
  77. parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
  78. parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
  79. parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
  80. parser.add_argument(
  81. "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"
  82. )
  83. return parser
  84. if __name__ == "__main__":
  85. args = get_args_parser().parse_args()
  86. # Set directories for your custom dataset
  87. train_dir = os.path.join(args.data_path, "train")
  88. val_dir = os.path.join(args.data_path, "val")
  89. # Set the directory where you want to save weights
  90. os.makedirs(args.output_dir, exist_ok=True)
  91. # Load data
  92. train_generator, val_generator = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)
  93. # Start training
  94. train_model(args, train_generator, val_generator)