train.py

import mindspore
import mindspore.nn as nn
from mindspore.dataset import vision, Cifar10Dataset
from mindspore.dataset import transforms  # TypeCast lives in mindspore.dataset.transforms, not in vision

from tests.model.AlexNet import AlexNet

# Load the CIFAR-10 binary dataset (expects the extracted cifar-10-batches-bin directory).
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')

batch_size = 32


def datapipe(dataset, batch_size):
    """Rescale, normalize and transpose the images, cast the labels, then batch."""
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    dataset = dataset.map(image_transforms, input_columns='image')
    dataset = dataset.map(label_transform, input_columns='label')
    dataset = dataset.batch(batch_size)
    return dataset


# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)
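
# Optional sanity check (not in the original script): pull one preprocessed batch and
# print its shape/dtype; with batch_size = 32 the images should come out as
# (32, 3, 32, 32) float32 and the labels as (32,) int32.
for _images, _labels in train_dataset.create_tuple_iterator():
    print("images:", _images.shape, _images.dtype, "| labels:", _labels.shape, _labels.dtype)
    break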

# Define model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), 1e-2)


# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits


# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)


# 3. Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train(model, dataset):
    """Run one epoch of training, logging the loss every 100 batches."""
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)
        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")


def test(model, dataset, loss_fn):
    """Evaluate the model on the test set and report accuracy and average loss."""
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

if __name__ == '__main__':
    epochs = 10
    mindspore.set_context(device_target="GPU")
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
    print("Done!")
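    # Optional (not in the original script): save the trained weights with
    # mindspore.save_checkpoint; the file name here is only an example.
    mindspore.save_checkpoint(model, "alexnet_cifar10.ckpt")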