소스 검색

vgg16训练添加归一化操作

liyan 7 달 전
부모
커밋
22544bcdf0
1개의 변경된 파일16개의 추가작업 그리고 4개의 파일을 삭제
  1. 16 4
      train_vgg16.py

+ 16 - 4
train_vgg16.py

@@ -8,6 +8,14 @@ from models.VGG16 import create_model
 
 
 
 
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
+    # 给定的均值和标准差
+    mean = tf.constant([0.485, 0.456, 0.406])
+    std = tf.constant([0.229, 0.224, 0.225])
+
+    # 自定义标准化函数
+    def normalize_image(x):
+        return (x - mean) / std  # 标准化
+
     # 使用 ImageDataGenerator 加载图像并进行预处理
     # 使用 ImageDataGenerator 加载图像并进行预处理
     train_datagen = ImageDataGenerator(
     train_datagen = ImageDataGenerator(
         rescale=1.0 / 255.0,  # 归一化
         rescale=1.0 / 255.0,  # 归一化
@@ -17,10 +25,14 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
         shear_range=0.2,
         shear_range=0.2,
         zoom_range=0.2,
         zoom_range=0.2,
         horizontal_flip=True,
         horizontal_flip=True,
-        fill_mode='nearest'
+        fill_mode='nearest',
+        preprocessing_function=normalize_image  # 使用自定义的标准化函数
     )
     )
 
 
-    val_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
+    val_datagen = ImageDataGenerator(
+        rescale=1.0 / 255.0,
+        preprocessing_function=normalize_image  # 使用自定义的标准化函数
+    )
 
 
     train_generator = train_datagen.flow_from_directory(
     train_generator = train_datagen.flow_from_directory(
         train_dir,
         train_dir,
@@ -85,7 +97,7 @@ def train_model(args, train_generator, val_generator):
 
 
     # Define ModelCheckpoint callback to save weights for each epoch
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
     checkpoint_callback = ModelCheckpoint(
-        os.path.join(args.output_dir, 'vgg16_loss_{val_loss:.4f}_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
+        os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'),  # Save weights as vgg16_{epoch}.h5
         save_weights_only=False,
         save_weights_only=False,
         monitor='val_loss',  # Monitor the validation loss
         monitor='val_loss',  # Monitor the validation loss
         save_freq='epoch',  # Save after every epoch
         save_freq='epoch',  # Save after every epoch
@@ -120,7 +132,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
     parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
 
 
     parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
     parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
-    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
 
 
     parser.add_argument(
     parser.add_argument(