train_alexnet.py

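"""Train AlexNet on an image-folder dataset with TensorFlow/Keras.

The script reads training and validation images from <data-path>/train and
<data-path>/val, applies light augmentation to the training split, and resumes
from the latest .h5 checkpoint found in the output directory.
"""
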
import os

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing import image_dataset_from_directory

from models.AlexNet import create_model
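
# `create_model` comes from models/AlexNet.py, which is not part of this file.
# As a rough, assumed reference only (not the project's actual definition), it is
# expected to return an uncompiled Keras model with a softmax classification head,
# roughly along these lines:
#
#     from tensorflow import keras
#
#     def create_model(num_classes=10, input_shape=(224, 224, 3)):
#         return keras.Sequential([
#             keras.layers.Conv2D(96, 11, strides=4, activation='relu',
#                                 input_shape=input_shape),
#             keras.layers.MaxPooling2D(3, strides=2),
#             keras.layers.Conv2D(256, 5, padding='same', activation='relu'),
#             keras.layers.MaxPooling2D(3, strides=2),
#             keras.layers.Conv2D(384, 3, padding='same', activation='relu'),
#             keras.layers.Conv2D(384, 3, padding='same', activation='relu'),
#             keras.layers.Conv2D(256, 3, padding='same', activation='relu'),
#             keras.layers.MaxPooling2D(3, strides=2),
#             keras.layers.Flatten(),
#             keras.layers.Dense(4096, activation='relu'),
#             keras.layers.Dropout(0.5),
#             keras.layers.Dense(4096, activation='relu'),
#             keras.layers.Dropout(0.5),
#             keras.layers.Dense(num_classes, activation='softmax'),
#         ])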


def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
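    """Build tf.data pipelines for the training and validation splits.

    Images are read from directory trees (one sub-directory per class),
    resized to `img_size`, batched, and rescaled to [0, 1]; random
    augmentation is applied to the training split only.

    Returns a (train_dataset, val_dataset) tuple of `tf.data.Dataset` objects.
    """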
    # Define data augmentation for the training set. An alternative Keras
    # preprocessing-layer pipeline is kept here for reference:
    # train_datagen = tf.keras.Sequential([
    #     tf.keras.layers.RandomFlip('horizontal'),
    #     tf.keras.layers.RandomRotation(0.2),
    #     tf.keras.layers.RandomZoom(0.2),
    #     tf.keras.layers.RandomContrast(0.2),
    # ])
    def augment(image):
        # Random horizontal flip
        image = tf.image.random_flip_left_right(image)
        # Random contrast adjustment
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        # Random brightness adjustment
        image = tf.image.random_brightness(image, max_delta=0.2)
        # Keep pixel values inside the valid [0, 1] range after the jitter
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image

    # Load training dataset
    train_dataset = image_dataset_from_directory(
        train_dir,
        image_size=img_size,          # Resize images to img_size
        batch_size=batch_size,
        label_mode='categorical',     # Return one-hot encoded labels
        shuffle=True
    )

    # Load validation dataset
    val_dataset = image_dataset_from_directory(
        val_dir,
        image_size=img_size,          # Resize images to img_size
        batch_size=batch_size,
        label_mode='categorical',     # Return one-hot encoded labels
        shuffle=False
    )

    # Rescale pixel values to [0, 1], then augment the training set
    # (normalizing first keeps the brightness/contrast jitter on the [0, 1] scale)
    train_dataset = train_dataset.map(
        lambda x, y: (augment(x / 255.0), y),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    val_dataset = val_dataset.map(
        lambda x, y: (x / 255.0, y),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Prefetch to improve performance
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, val_dataset


def find_latest_checkpoint(directory):
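    """Return the path of the newest .h5 checkpoint in `directory`, or None.

    Checkpoint filenames are expected to end in `_{epoch}.h5`, as written by
    the ModelCheckpoint callback below, so sorting by that number yields the
    most recent file.
    """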
    # Collect all .h5 files in the given directory
    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
    if not checkpoint_files:
        return None

    # Sort by the epoch number embedded in the filename and pick the latest file
    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
    return os.path.join(directory, checkpoint_files[-1])


def train_model(args, train_data, val_data):
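    """Compile AlexNet and train it, resuming from the latest checkpoint if one exists.

    The optimizer is chosen from `args.opt` ('sgd' or 'adam'), the history is
    appended to training_log.csv, and the full model is saved after every epoch.
    Returns the Keras History object from `model.fit`.
    """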
    # Create model
    model = create_model()

    # Learning rate (fall back to 1e-2 if none is given)
    learning_rate = args.lr if args.lr else 1e-2

    # Select optimizer based on args.opt
    if args.opt == 'sgd':
        optimizer = SGD(learning_rate=learning_rate,
                        momentum=args.momentum if args.momentum else 0.0)
    elif args.opt == 'adam':
        optimizer = Adam(learning_rate=learning_rate)
    else:
        optimizer = Adam(learning_rate=learning_rate)  # Default to Adam if unspecified

    # Compile model
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    # Check if a checkpoint exists and determine the initial epoch
    latest_checkpoint = find_latest_checkpoint(args.output_dir)
    if latest_checkpoint:
        model.load_weights(latest_checkpoint)  # Load the weights from the checkpoint
        initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # Last completed epoch from the filename
        print(f"Resuming training from epoch {initial_epoch}")
    else:
        initial_epoch = 0
        print("No checkpoint found. Starting training from scratch.")

    # Define CSVLogger to log training history to a CSV file
    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

    # Define ModelCheckpoint callback to save the full model after every epoch
    checkpoint_callback = ModelCheckpoint(
        os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save the model as alexnet_{epoch:03d}.h5
        save_weights_only=False,
        save_freq='epoch',  # Save after every epoch
        verbose=1
    )

    # Train the model
    history = model.fit(
        train_data,
        epochs=args.epochs,
        validation_data=val_data,
        initial_epoch=initial_epoch,
        callbacks=[csv_logger, checkpoint_callback],
    )
    return history


def get_args_parser(add_help=True):
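    """Build the command-line argument parser for this training script."""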
    import argparse

    parser = argparse.ArgumentParser(description="TensorFlow/Keras Classification Training", add_help=add_help)
    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
    parser.add_argument("--output-dir", default="checkpoints/alexnet", type=str, help="path to save outputs")
    parser.add_argument("-b", "--batch-size", default=64, type=int, help="number of images per batch")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer ('sgd' or 'adam')")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--input-size", default=224, type=int, help="the image size used for training (default: 224)"
    )
    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Set directories for your custom dataset
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    # Set the directory where you want to save weights
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_data, val_data = load_data(train_dir, val_dir,
                                     img_size=(args.input_size, args.input_size),
                                     batch_size=args.batch_size)

    # Start training
    train_model(args, train_data, val_data)
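
# Example invocation, assuming the Imagenette dataset has been extracted to
# dataset/imagenette2-320 (the --data-path default):
#
#     python train_alexnet.py --data-path dataset/imagenette2-320 \
#         --output-dir checkpoints/alexnet --batch-size 64 --epochs 90 \
#         --opt sgd --lr 0.1 --momentum 0.9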