@@ -0,0 +1,139 @@
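+"""Train an AlexNet image classifier with TensorFlow/Keras.
+
+Loads train/val splits from an image-folder dataset, applies data augmentation,
+and writes per-epoch weight checkpoints plus a CSV training log.
+"""
+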
+import os
+import glob
+
+import tensorflow as tf
+from tensorflow.keras.optimizers import Adam, SGD
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
+from tensorflow.keras.preprocessing import image_dataset_from_directory
+
+from models.AlexNet import create_model
+
+
+def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
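+    """Build tf.data pipelines for the training and validation splits.
+
+    Images are resized to ``img_size``, labels are one-hot encoded
+    (label_mode='categorical'), pixel values are rescaled to [0, 1], and random
+    augmentation is applied to the training split only.
+    """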
+    # Define data augmentation for the training set
+    train_datagen = tf.keras.Sequential([
+        tf.keras.layers.RandomFlip('horizontal'),
+        tf.keras.layers.RandomRotation(0.2),
+        tf.keras.layers.RandomZoom(0.2),
+        tf.keras.layers.RandomContrast(0.2),
+    ])
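+    # Note: these random augmentation layers only transform images when called
+    # with training=True (as done in the dataset map below); with training=False
+    # they pass inputs through unchanged.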
+
+    # Load training dataset
+    train_dataset = image_dataset_from_directory(
+        train_dir,
+        image_size=img_size,        # Resize images to img_size
+        batch_size=batch_size,
+        label_mode='categorical',   # Return one-hot encoded labels
+        shuffle=True
+    )
+
+    # Load validation dataset
+    val_dataset = image_dataset_from_directory(
+        val_dir,
+        image_size=img_size,        # Resize images to img_size
+        batch_size=batch_size,
+        label_mode='categorical',   # Return one-hot encoded labels
+        shuffle=False
+    )
+
+    # Apply augmentation (training only) and rescale pixel values to [0, 1]
+    train_dataset = train_dataset.map(
+        lambda x, y: (train_datagen(x, training=True) / 255.0, y),
+    )
+
+    val_dataset = val_dataset.map(
+        lambda x, y: (x / 255.0, y),
+    )
+
+    return train_dataset, val_dataset
+
+
+def train_model(args, train_data, val_data):
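+    """Compile the model, resume from the newest saved checkpoint if one
+    exists, and train with per-epoch weight checkpoints and CSV logging."""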
+    # Create model
+    model = create_model()
+
+    # Adjust the learning rate (used by the SGD variant below; Adam uses a fixed 1e-4)
+    learning_rate = args.lr if args.lr else 1e-2
+    # optimizer = SGD(learning_rate=learning_rate, momentum=args.momentum)
+
+    # Compile model
+    model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
+
+    # Check for existing .h5 checkpoints and resume from the latest one
+    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, 'alexnet_*.h5')))
+    if checkpoints:
+        latest_checkpoint = checkpoints[-1]
+        # Filenames follow alexnet_{epoch:03d}.h5, so parse the last epoch from the name
+        initial_epoch = int(os.path.basename(latest_checkpoint).split('_')[-1].split('.')[0])
+        model.load_weights(latest_checkpoint)
+        print(f"Resuming training from epoch {initial_epoch}")
+    else:
+        initial_epoch = 0
+
+    # Define CSVLogger to log training history to a CSV file
+    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
+
+    # Define ModelCheckpoint callback to save weights after each epoch
+    checkpoint_callback = ModelCheckpoint(
+        os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),  # Save weights as alexnet_{epoch}.h5
+        save_weights_only=True,
+        save_freq='epoch',  # Save after every epoch
+        verbose=1
+    )
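+    # Note: with save_weights_only=True only the layer weights are written to the
+    # .h5 files; optimizer state is not saved, so a resumed run restarts Adam's state.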
+
+    # Train the model
+    history = model.fit(
+        train_data,
+        epochs=args.epochs,
+        validation_data=val_data,
+        initial_epoch=initial_epoch,
+        callbacks=[csv_logger, checkpoint_callback],  # Log to CSV and checkpoint each epoch
+    )
+
+    return history
+
+
+def get_args_parser(add_help=True):
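+    """Build the command-line argument parser.
+
+    Note: several optimizer/scheduler options (e.g. --device, --opt, --momentum,
+    the --lr-* schedule flags and --start-epoch) are parsed but not currently
+    wired into the TensorFlow training loop.
+    """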
+    import argparse
+
+    parser = argparse.ArgumentParser(description="TensorFlow Classification Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
+    parser.add_argument("--output-dir", default="checkpoints/alexnet", type=str, help="path to save outputs")
+
+    parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu, default: cuda)")
+    parser.add_argument("-b", "--batch-size", default=64, type=int, help="images per batch")
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
+    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
+    parser.add_argument(
+        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
+    )
+    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+
+    parser.add_argument(
+        "--input-size", default=224, type=int, help="input image size used for training (default: 224)"
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+
+    # Set directories for your custom dataset
+    train_dir = os.path.join(args.data_path, "train")
+    val_dir = os.path.join(args.data_path, "val")
+
+    # Create the directory where weights and logs will be saved
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load data
+    train_data, val_data = load_data(
+        train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size
+    )
+
+    # Start training
+    train_model(args, train_data, val_data)
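+
+# Example invocation (assuming this script is saved as train.py; adjust the name
+# and paths to your repository layout):
+#   python train.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet --epochs 90 -b 64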
|