import os

import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.preprocessing import image_dataset_from_directory

from models.AlexNet import create_model
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
    """Build training and validation tf.data pipelines with on-the-fly augmentation."""

    def augment(image):
        # Random horizontal flip
        image = tf.image.random_flip_left_right(image)
        # Random contrast adjustment
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        # Random brightness adjustment (image is expected to be scaled to [0, 1] here,
        # so max_delta=0.2 is a meaningful shift)
        image = tf.image.random_brightness(image, max_delta=0.2)
        return image

    # Load training dataset
    train_dataset = image_dataset_from_directory(
        train_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=True
    )

    # Load validation dataset
    val_dataset = image_dataset_from_directory(
        val_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=False
    )

    # Mean and std for standardization (ImageNet statistics)
    mean = tf.constant([0.485, 0.456, 0.406])
    std = tf.constant([0.229, 0.224, 0.225])

    # Scale to [0, 1], augment (training set only), then standardize with the ImageNet statistics
    train_dataset = train_dataset.map(
        lambda x, y: ((augment(x / 255.0) - mean) / std, y),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    val_dataset = val_dataset.map(
        lambda x, y: ((x / 255.0 - mean) / std, y),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # Prefetch to overlap preprocessing with training
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, val_dataset
def find_latest_checkpoint(directory):
    # Collect all .h5 checkpoint files in the given directory
    checkpoint_files = [f for f in os.listdir(directory) if f.endswith('.h5')]
    if not checkpoint_files:
        return None
    # Sort by the epoch number embedded in the filename (e.g. alexnet_012.h5) and return the latest
    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
    return os.path.join(directory, checkpoint_files[-1])
def train_model(args, train_data, val_data):
    # Create model
    model = create_model()

    # Learning rate (fall back to 1e-2 if not provided)
    learning_rate = args.lr if args.lr else 1e-2

    # Select optimizer based on args.opt
    if args.opt == 'sgd':
        optimizer = SGD(learning_rate=learning_rate,
                        momentum=args.momentum if args.momentum else 0.0)
    elif args.opt == 'adam':
        optimizer = Adam(learning_rate=learning_rate)
    else:
        optimizer = Adam(learning_rate=learning_rate)  # Default to Adam if unspecified

    # Compile model
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    # Check whether a checkpoint exists and determine the initial epoch
    latest_checkpoint = find_latest_checkpoint(args.output_dir)
    if latest_checkpoint:
        model.load_weights(latest_checkpoint)  # Load the weights from the checkpoint
        initial_epoch = int(latest_checkpoint.split('_')[-1].split('.')[0])  # Last epoch from the filename
        print(f"Resuming training from epoch {initial_epoch}")
    else:
        initial_epoch = 0
        print("No checkpoint found. Starting training from scratch.")

    # Log the training history to a CSV file
    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)

    # Save the full model after every epoch so training can be resumed later
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
        save_weights_only=False,
        save_freq='epoch',   # Save after every epoch
        monitor='val_loss',  # Monitor the validation loss
        verbose=1
    )

    # Train the model
    history = model.fit(
        train_data,
        epochs=args.epochs,
        validation_data=val_data,
        initial_epoch=initial_epoch,
        callbacks=[csv_logger, checkpoint_callback],
    )
    return history
def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="AlexNet Classification Training (TensorFlow/Keras)", add_help=add_help)
    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
    parser.add_argument("--output-dir", default="checkpoints/alexnet", type=str, help="path to save outputs")
    parser.add_argument("-b", "--batch-size", default=64, type=int, help="batch size")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer ('sgd' or 'adam')")
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum (SGD only)")
    parser.add_argument(
        "--input-size", default=224, type=int, help="input image size used for training (default: 224)"
    )
    return parser
if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Directories for the custom dataset (expects train/ and val/ subfolders)
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    # Directory where checkpoints and logs are saved
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size),
                                     batch_size=args.batch_size)

    # Start training
    train_model(args, train_data, val_data)
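
# Example invocation (a sketch; the script filename "train.py" and the imagenette2-320
# directory layout with train/ and val/ subfolders are assumptions, not fixed by this code):
#
#   python train.py --data-path dataset/imagenette2-320 --output-dir checkpoints/alexnet \
#       --batch-size 64 --epochs 90 --opt sgd --lr 0.01 --momentum 0.9
#
# Re-running the same command after an interruption resumes from the newest
# checkpoints/alexnet/alexnet_XXX.h5 file via find_latest_checkpoint().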