train_vgg16.py

import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from models.VGG16 import create_model

def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
    # Channel-wise mean and standard deviation (ImageNet statistics)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Custom standardization function. ImageDataGenerator applies
    # preprocessing_function before `rescale`, so the [0, 1] scaling is done
    # here instead of passing rescale=1.0 / 255.0.
    def normalize_image(x):
        x = x / 255.0            # scale to [0, 1]
        return (x - mean) / std  # standardize with ImageNet statistics

    # Load and augment images with ImageDataGenerator
    train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        preprocessing_function=normalize_image  # custom standardization function
    )
    val_datagen = ImageDataGenerator(
        preprocessing_function=normalize_image  # custom standardization function
    )

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=img_size,      # VGG16 input size
        batch_size=batch_size,
        class_mode='categorical'   # multi-class classification
    )
    val_generator = val_datagen.flow_from_directory(
        val_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical'
    )
    return train_generator, val_generator

def find_latest_checkpoint(directory):
    # Collect all .h5 checkpoint files in the given directory
    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
    if not checkpoint_files:
        return None
    # Sort by the epoch number embedded in the filename and return the latest checkpoint
    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
    return os.path.join(directory, checkpoint_files[-1])

def train_model(args, train_generator, val_generator):
    # Create model
    model = create_model()

    # Learning rate (falls back to 1e-2 when not provided)
    learning_rate = args.lr if args.lr else 1e-2

    # Select optimizer based on args.opt
    if args.opt == 'sgd':
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                            momentum=args.momentum if args.momentum else 0.0)
    elif args.opt == 'adam':
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # default to Adam if unspecified

    # Compile the model
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Check if a checkpoint exists and determine the initial_epoch
    latest_checkpoint = find_latest_checkpoint(args.output_dir)
    if latest_checkpoint:
        model.load_weights(latest_checkpoint)  # Load the weights from the checkpoint
        initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # Last completed epoch from the filename
        print(f"Resuming training from epoch {initial_epoch}")
    else:
        initial_epoch = 0
        print("No checkpoint found. Starting training from scratch.")

    # Define CSVLogger to log training history to a CSV file
    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

    # Define ModelCheckpoint callback to save the model after each epoch
    checkpoint_callback = ModelCheckpoint(
        os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),  # Save as vgg16_{epoch}.h5
        save_weights_only=False,
        monitor='val_loss',   # Monitor the validation loss
        save_freq='epoch',    # Save after every epoch
        verbose=1
    )

    # Train the model
    history = model.fit(
        train_generator,
        steps_per_epoch=train_generator.samples // train_generator.batch_size,
        epochs=args.epochs,
        validation_data=val_generator,
        validation_steps=val_generator.samples // val_generator.batch_size,
        initial_epoch=initial_epoch,
        callbacks=[csv_logger, checkpoint_callback]
    )
    return history

def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="TensorFlow VGG16 Classification Training", add_help=add_help)
    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
    parser.add_argument("--output-dir", default="checkpoints/vgg16", type=str, help="path to save outputs")
    parser.add_argument(
        "-b", "--batch-size", default=2, type=int, help="images per GPU; the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    return parser

if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Set directories for your custom dataset
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    # Set the directory where you want to save weights
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_generator, val_generator = load_data(train_dir, val_dir,
                                               img_size=(args.input_size, args.input_size),
                                               batch_size=args.batch_size)

    # Start training
    train_model(args, train_generator, val_generator)
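
# A minimal example invocation, for reference only. It uses the argument
# defaults defined above; the dataset layout (train/ and val/ subfolders under
# --data-path) is assumed, and the batch size shown here is illustrative:
#
#   python train_vgg16.py --data-path dataset/imagenette2-320 \
#       --output-dir checkpoints/vgg16 --opt sgd --lr 0.01 --momentum 0.9 \
#       --epochs 90 --batch-size 32 --input-size 224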