瀏覽代碼

优化训练效率

liyan 7 月之前
父節點
當前提交
8610a0f609
共有 1 個文件被更改,包括 15 次插入7 次删除
  1. 15 7
      train_alexnet.py

+ 15 - 7
train_alexnet.py

@@ -10,12 +10,20 @@ from tensorflow.keras.preprocessing import image_dataset_from_directory
 
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Define data augmentation for the training set
-    train_datagen = tf.keras.Sequential([
-        tf.keras.layers.RandomFlip('horizontal'),
-        tf.keras.layers.RandomRotation(0.2),
-        tf.keras.layers.RandomZoom(0.2),
-        tf.keras.layers.RandomContrast(0.2),
-    ])
+    # train_datagen = tf.keras.Sequential([
+    #     tf.keras.layers.RandomFlip('horizontal'),
+    #     tf.keras.layers.RandomRotation(0.2),
+    #     tf.keras.layers.RandomZoom(0.2),
+    #     tf.keras.layers.RandomContrast(0.2),
+    # ])
+    def augment(image):
+        # Random horizontal flip
+        image = tf.image.random_flip_left_right(image)
+        # Random contrast adjustment
+        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
+        # Random brightness adjustment
+        image = tf.image.random_brightness(image, max_delta=0.2)
+        return image
 
     # Load training dataset
     train_dataset = image_dataset_from_directory(
@@ -37,7 +45,7 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
 
     # Normalize the datasets (rescale pixel values to [0, 1])
     train_dataset = train_dataset.map(
-        lambda x, y: (train_datagen(x) / 255.0, y),
+        lambda x, y: (augment(x) / 255.0, y),
         num_parallel_calls=tf.data.AUTOTUNE
     )