Переглянути джерело

开发基于tensorflow框架的VGG16模型训练代码

liyan 7 місяців тому
батько
коміт
acef196f0c
1 змінених файлів з 124 додано та 0 видалено
  1. 124 0
      train_vgg16.py

+ 124 - 0
train_vgg16.py

@@ -0,0 +1,124 @@
+import os
+
+import tensorflow as tf
+from keras.callbacks import ModelCheckpoint, CSVLogger
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+from tensorflow.keras.applications import VGG16
+
+
+def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
+    # 使用 ImageDataGenerator 加载图像并进行预处理
+    train_datagen = ImageDataGenerator(
+        rescale=1.0 / 255.0,  # 归一化
+        rotation_range=40,
+        width_shift_range=0.2,
+        height_shift_range=0.2,
+        shear_range=0.2,
+        zoom_range=0.2,
+        horizontal_flip=True,
+        fill_mode='nearest'
+    )
+
+    val_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
+
+    train_generator = train_datagen.flow_from_directory(
+        train_dir,
+        target_size=img_size,  # VGG16 输入大小
+        batch_size=batch_size,
+        class_mode='categorical'  # 多分类任务
+    )
+
+    val_generator = val_datagen.flow_from_directory(
+        val_dir,
+        target_size=img_size,
+        batch_size=batch_size,
+        class_mode='categorical'
+    )
+
+    return train_generator, val_generator
+
+
+def train_model(args, train_generator, val_generator):
+    # Create model
+    model = VGG16(weights=None, include_top=True, input_shape=(224, 224, 3), classes=10)
+
+    # 调整学习率
+    learning_rate = args.lr if args.lr else 1e-2
+
+    # 编译模型
+    model.compile(optimizer=tf.keras.optimizers.Adam(),
+                  loss='categorical_crossentropy',
+                  metrics=['accuracy'])
+
+    # Define CSVLogger to log training history to a CSV file
+    csv_logger = CSVLogger(os.path.join(args.output_dir, 'training_log.csv'), append=True)
+
+    # Define ModelCheckpoint callback to save weights for each epoch
+    checkpoint_callback = ModelCheckpoint(
+        os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
+        save_weights_only=False,
+        save_freq='epoch',  # Save after every epoch
+        verbose=1
+    )
+
+    # 训练模型
+    history = model.fit(
+        train_generator,
+        steps_per_epoch=train_generator.samples // train_generator.batch_size,
+        epochs=args.epochs,
+        validation_data=val_generator,
+        validation_steps=val_generator.samples // val_generator.batch_size,
+        callbacks=[csv_logger, checkpoint_callback]
+    )
+
+    return history
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="dataset/imagenette2-320", type=str, help="dataset path")
+    parser.add_argument("--output-dir", default="checkpoints/alexnet", type=str, help="path to save outputs")
+
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
+    parser.add_argument(
+        "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+    )
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
+    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
+    parser.add_argument(
+        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
+    )
+    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+
+    parser.add_argument(
+        "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
+    return parser
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+
+    # Set directories for your custom dataset
+    train_dir = os.path.join(args.data_path, "train")
+    val_dir = os.path.join(args.data_path, "val")
+
+    # Set the directory where you want to save weights
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load data
+    train_generator, val_generator = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)
+
+    # Start training
+    train_model(args, train_generator, val_generator)