瀏覽代碼

解决训练过程中,显卡占用率低问题

liyan 7 月之前
父節點
當前提交
d37a33f948
共有 1 個文件被更改,包括 6 次插入0 次删除
  1. 6 0
      train_alexnet.py

+ 6 - 0
train_alexnet.py

@@ -38,12 +38,18 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Normalize the datasets (rescale pixel values to [0, 1])
     train_dataset = train_dataset.map(
         lambda x, y: (train_datagen(x) / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
 
     val_dataset = val_dataset.map(
         lambda x, y: (x / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
 
+    # Prefetch to improve performance
+    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+
     return train_dataset, val_dataset