train_alexnet.py

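"""Train an AlexNet classifier with TensorFlow/Keras on an image-folder dataset.

The directory given by --data-path is expected to contain train/ and val/
subdirectories in the layout consumed by image_dataset_from_directory (one
subdirectory per class). Per-epoch weight checkpoints and a CSV training log
are written to --output-dir, and training resumes from the latest checkpoint
if one is found there.
"""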
import os
import glob
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.preprocessing import image_dataset_from_directory
from models.AlexNet import create_model
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
    # Data augmentation pipeline, applied to the training set only
    train_datagen = tf.keras.Sequential([
        tf.keras.layers.RandomFlip('horizontal'),
        tf.keras.layers.RandomRotation(0.2),
        tf.keras.layers.RandomZoom(0.2),
        tf.keras.layers.RandomContrast(0.2),
    ])

    # Load the training dataset
    train_dataset = image_dataset_from_directory(
        train_dir,
        image_size=img_size,        # Resize images to img_size
        batch_size=batch_size,
        label_mode='categorical',   # Return one-hot encoded labels
        shuffle=True
    )

    # Load the validation dataset
    val_dataset = image_dataset_from_directory(
        val_dir,
        image_size=img_size,        # Resize images to img_size
        batch_size=batch_size,
        label_mode='categorical',   # Return one-hot encoded labels
        shuffle=False
    )

    # Augment (training set only) and rescale pixel values to [0, 1]
    train_dataset = train_dataset.map(
        lambda x, y: (train_datagen(x) / 255.0, y),
    )
    val_dataset = val_dataset.map(
        lambda x, y: (x / 255.0, y),
    )

    return train_dataset, val_dataset
def train_model(args, train_data, val_data):
    # Create the model
    model = create_model()

    # Learning rate from the command line (only used by the commented-out SGD optimizer)
    learning_rate = args.lr if args.lr else 1e-2
    # optimizer = SGD(learning_rate=learning_rate, momentum=args.momentum)

    # Compile the model
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

    # If weight checkpoints already exist, resume from the most recent one
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, 'alexnet_*.h5')))
    if checkpoints:
        latest_checkpoint = checkpoints[-1]
        model.load_weights(latest_checkpoint)
        # The epoch number is encoded in the filename, e.g. alexnet_017.h5
        initial_epoch = int(os.path.basename(latest_checkpoint).split('_')[-1].split('.')[0])
        print(f"Resuming training from epoch {initial_epoch}")
    else:
        initial_epoch = 0

    # Log the training history to a CSV file
    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

    # Save the weights after every epoch as alexnet_{epoch}.h5
    checkpoint_callback = ModelCheckpoint(
        os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
        save_weights_only=True,
        save_freq='epoch',
        verbose=1
    )

    # Train the model
    history = model.fit(
        train_data,
        epochs=args.epochs,
        validation_data=val_data,
        initial_epoch=initial_epoch,
        callbacks=[csv_logger, checkpoint_callback],
    )

    return history
def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="Keras AlexNet Classification Training", add_help=add_help)
    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
    parser.add_argument("--output-dir", default="checkpoints/alexnet", type=str, help="path to save outputs")
    parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu, default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=64, type=int, help="images per batch"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--input-size", default=224, type=int, help="the input image size used for training (default: 224)"
    )
    return parser
if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Set directories for your custom dataset
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    # Set the directory where you want to save weights
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)

    # Start training
    train_model(args, train_data, val_data)
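# Example usage (assuming the Imagenette dataset has been downloaded to the
# default --data-path location):
#   python train_alexnet.py --data-path dataset/imagenette2-320 \
#       --output-dir checkpoints/alexnet --batch-size 64 --epochs 90 --input-size 224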