# train.py: train AlexNet on CIFAR-10 with MindSpore

import mindspore
import mindspore.nn as nn
from mindspore.dataset import vision, transforms, Cifar10Dataset
import tqdm

from tests.model.AlexNet import AlexNet

# Load the CIFAR-10 binary dataset (extracted under data/)
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')

batch_size = 32
def datapipe(dataset, batch_size):
    """Apply image/label transforms and batch the dataset."""
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),                   # scale pixel values to [0, 1]
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),  # normalize with the given mean/std
        vision.HWC2CHW()                                  # HWC -> CHW layout
    ]
    label_transform = transforms.TypeCast(mindspore.int32)  # labels must be int32
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset
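
# Optional sanity check (a minimal sketch, not part of the original script): peek at
# one small batch to confirm the pipeline emits CHW float images and int32 labels.
# sample_images, sample_labels = next(datapipe(
#     Cifar10Dataset('data/cifar-10-batches-bin', usage='test'), 2).create_tuple_iterator())
# print(sample_images.shape, sample_images.dtype)  # expected: (2, 3, 32, 32) Float32
# print(sample_labels.shape, sample_labels.dtype)  # expected: (2,) Int32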
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)

# Define model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)
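
# Optional shape check (a sketch; assumes AlexNet accepts NCHW float32 input as
# configured above): a dummy forward pass should yield one logit per CIFAR-10 class.
# import numpy as np
# dummy = mindspore.Tensor(np.zeros((1, 3, 32, 32), dtype=np.float32))
# print(model(dummy).shape)  # expected: (1, 10)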
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=1e-2)
# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 2. Get gradient function: differentiate w.r.t. the optimizer's parameters;
# has_aux=True means only the first output (loss) is differentiated and the
# remaining outputs (logits) are passed through unchanged.
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)  # apply gradients to update the parameters
    return loss
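
# Optional smoke test (a sketch, not part of the original flow): run a single
# optimization step on one batch and confirm the loss comes back as a scalar Tensor.
# first_data, first_label = next(train_dataset.create_tuple_iterator())
# print(train_step(first_data, first_label))  # e.g. a scalar loss around ln(10) ~ 2.3 at start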
def train(model, dataset):
    num_batches = dataset.get_dataset_size()
    tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
    model.set_train()
    for data, label in dataset.create_tuple_iterator():
        loss = train_step(data, label)
        tqdm_show.set_postfix({'train_loss': loss.asnumpy()})  # show the running loss
        tqdm_show.update(1)  # advance the progress bar
    tqdm_show.close()
def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    tqdm_show = tqdm.tqdm(total=num_batches, desc='test', unit='batch')
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        loss = loss_fn(pred, label).asnumpy()
        tqdm_show.set_postfix({'val_loss': loss})  # show the running loss
        tqdm_show.update(1)  # advance the progress bar
        test_loss += loss
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    tqdm_show.close()
if __name__ == '__main__':
    epochs = 10
    # Select the device target; switch to "CPU" or "Ascend" as needed.
    mindspore.set_context(device_target="GPU")
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
    print("Done!")