瀏覽代碼

训练过程更新进度条显示

liyan 11 月之前
父節點
當前提交
2c73f94121
共有 1 個文件被更改,包括 19 次插入和 8 次刪除
  1. 19 8
      tests/train.py

+ 19 - 8
tests/train.py

@@ -2,6 +2,7 @@ import mindspore
 import mindspore.nn as nn
 from mindspore.dataset import vision, Cifar10Dataset
 from mindspore.dataset.vision import transforms
+import tqdm
 
 from tests.model.AlexNet import AlexNet
 
@@ -22,11 +23,11 @@ def datapipe(dataset, batch_size):
     dataset = dataset.batch(batch_size)
     return dataset
 
+
 # Map vision transforms and batch dataset
 train_dataset = datapipe(train_dataset, batch_size)
 test_dataset = datapipe(test_dataset, batch_size)
 
-
 # Define model
 model = AlexNet(input_channels=3, output_num=10, input_size=32)
 print(model)
@@ -35,43 +36,53 @@ print(model)
 loss_fn = nn.CrossEntropyLoss()
 optimizer = nn.Adam(model.trainable_params(), 1e-2)
 
+
 # 1. Define forward function
 def forward_fn(data, label):
     logits = model(data)
     loss = loss_fn(logits, label)
     return loss, logits
 
+
 # 2. Get gradient function
 grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
 
+
 # 3. Define function of one-step training
 def train_step(data, label):
     (loss, _), grads = grad_fn(data, label)
     optimizer(grads)
     return loss
 
+
 def train(model, dataset):
-    size = dataset.get_dataset_size()
+    num_batches = dataset.get_dataset_size()
+    tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
     model.set_train()
     for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
         loss = train_step(data, label)
+        tqdm_show.set_postfix({'train_loss': loss.asnumpy()})  # 添加显示
+        tqdm_show.update(1)  # 更新进度条
+    tqdm_show.close()
 
-        if batch % 100 == 0:
-            loss, current = loss.asnumpy(), batch
-            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
 
 def test(model, dataset, loss_fn):
     num_batches = dataset.get_dataset_size()
     model.set_train(False)
     total, test_loss, correct = 0, 0, 0
+    tqdm_show = tqdm.tqdm(total=num_batches, desc='test', unit='batch')
     for data, label in dataset.create_tuple_iterator():
         pred = model(data)
         total += len(data)
-        test_loss += loss_fn(pred, label).asnumpy()
+        loss = loss_fn(pred, label).asnumpy()
+        tqdm_show.set_postfix({'val_loss': loss})  # 添加显示
+        tqdm_show.update(1)  # 更新进度条
+        test_loss += loss
         correct += (pred.argmax(1) == label).asnumpy().sum()
     test_loss /= num_batches
     correct /= total
-    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
+    tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
+    tqdm_show.close()
 
 
 if __name__ == '__main__':
@@ -81,4 +92,4 @@ if __name__ == '__main__':
         print(f"Epoch {t + 1}\n-------------------------------")
         train(model, train_dataset)
         test(model, test_dataset, loss_fn)
-    print("Done!")
+    print("Done!")