Explorar el Código

解决训练过程中,显卡占用率低问题

liyan hace 7 meses
padre
commit
d37a33f948
Se han modificado 1 ficheros con 6 adiciones y 0 borrados
  1. 6 0
      train_alexnet.py

+ 6 - 0
train_alexnet.py

@@ -38,12 +38,18 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     # Normalize the datasets (rescale pixel values to [0, 1])
     # Normalize the datasets (rescale pixel values to [0, 1])
     train_dataset = train_dataset.map(
     train_dataset = train_dataset.map(
         lambda x, y: (train_datagen(x) / 255.0, y),
         lambda x, y: (train_datagen(x) / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
     )
 
 
     val_dataset = val_dataset.map(
     val_dataset = val_dataset.map(
         lambda x, y: (x / 255.0, y),
         lambda x, y: (x / 255.0, y),
+        num_parallel_calls=tf.data.AUTOTUNE
     )
     )
 
 
+    # Prefetch to improve performance
+    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
+
     return train_dataset, val_dataset
     return train_dataset, val_dataset