浏览代码

alexnet训练添加归一化操作

liyan 7 月之前
父节点
当前提交
482fda832e
共有 1 个文件被更改,包括 12 次插入、15 次删除
  1. 12 15
      train_alexnet.py

+ 12 - 15
train_alexnet.py

@@ -9,13 +9,6 @@ from tensorflow.keras.preprocessing import image_dataset_from_directory
 
 
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
-    # Define data augmentation for the training set
-    # train_datagen = tf.keras.Sequential([
-    #     tf.keras.layers.RandomFlip('horizontal'),
-    #     tf.keras.layers.RandomRotation(0.2),
-    #     tf.keras.layers.RandomZoom(0.2),
-    #     tf.keras.layers.RandomContrast(0.2),
-    # ])
     def augment(image):
         # Random horizontal flip
         image = tf.image.random_flip_left_right(image)
@@ -28,29 +21,33 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Load training dataset
     train_dataset = image_dataset_from_directory(
         train_dir,
-        image_size=img_size,  # Resize images to (224, 224)
+        image_size=img_size,
         batch_size=batch_size,
-        label_mode='categorical',  # Return integer labels
+        label_mode='categorical',
         shuffle=True
     )
 
     # Load validation dataset
     val_dataset = image_dataset_from_directory(
         val_dir,
-        image_size=img_size,  # Resize images to (224, 224)
+        image_size=img_size,
         batch_size=batch_size,
-        label_mode='categorical',  # Return integer labels
+        label_mode='categorical',
         shuffle=False
     )
 
-    # Normalize the datasets (rescale pixel values to [0, 1])
+    # Define mean and std for standardization (ImageNet values)
+    mean = tf.constant([0.485, 0.456, 0.406])
+    std = tf.constant([0.229, 0.224, 0.225])
+
+    # Normalize and standardize the datasets
     train_dataset = train_dataset.map(
-        lambda x, y: (augment(x) / 255.0, y),
+        lambda x, y: ((augment(x) / 255.0 - mean) / std, y),
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
     val_dataset = val_dataset.map(
-        lambda x, y: (x / 255.0, y),
+        lambda x, y: ((x / 255.0 - mean) / std, y),
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
@@ -105,7 +102,7 @@ def train_model(args, train_data, val_data):
 
     # Define ModelCheckpoint callback to save weights for each epoch
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(args.output_dir, 'alexnet_loss_{val_loss:.4f}_{epoch:03d}.h5'),
+        filepath=os.path.join(args.output_dir, 'alexnet_{epoch:03d}.h5'),
         save_weights_only=False,
         save_freq='epoch',  # Save after every epoch
         monitor='val_loss',  # Monitor the validation loss