Переглянути джерело

解决训练过程中,显卡占用率低问题

liyan 7 місяців тому
батько
коміт
d37a33f948
1 змінених файлів з 6 додано та 0 видалено
  1. 6 0
      train_alexnet.py

+ 6 - 0
train_alexnet.py

@@ -38,12 +38,18 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Normalize the datasets (rescale pixel values to [0, 1])
     train_dataset = train_dataset.map(
         lambda x, y: (train_datagen(x) / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
 
     val_dataset = val_dataset.map(
         lambda x, y: (x / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
 
+    # Prefetch to improve performance
+    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+
     return train_dataset, val_dataset