|
@@ -0,0 +1,84 @@
|
|
|
|
+import mindspore
|
|
|
|
+import mindspore.nn as nn
|
|
|
|
+from mindspore.dataset import vision, Cifar10Dataset
|
|
|
|
+from mindspore.dataset.vision import transforms
|
|
|
|
+
|
|
|
|
+from tests.model.AlexNet import AlexNet
|
|
|
|
+
|
|
|
|
+train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
|
|
|
|
+test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
|
|
|
|
+batch_size = 32
|
|
|
|
+
|
|
|
|
+def datapipe(dataset, batch_size):
|
|
|
|
+ image_transforms = [
|
|
|
|
+ vision.Rescale(1.0 / 255.0, 0),
|
|
|
|
+ vision.Normalize(mean=(0.1307,), std=(0.3081,)),
|
|
|
|
+ vision.HWC2CHW()
|
|
|
|
+ ]
|
|
|
|
+ label_transform = transforms.TypeCast(mindspore.int32)
|
|
|
|
+
|
|
|
|
+ dataset = dataset.map(image_transforms, 'image')
|
|
|
|
+ dataset = dataset.map(label_transform, 'label')
|
|
|
|
+ dataset = dataset.batch(batch_size)
|
|
|
|
+ return dataset
|
|
|
|
+
|
|
|
|
+# Map vision transforms and batch dataset
|
|
|
|
+train_dataset = datapipe(train_dataset, batch_size)
|
|
|
|
+test_dataset = datapipe(test_dataset, batch_size)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Define model
|
|
|
|
+model = AlexNet(input_channels=3, output_num=10, input_size=32)
|
|
|
|
+print(model)
|
|
|
|
+
|
|
|
|
+# Instantiate loss function and optimizer
|
|
|
|
+loss_fn = nn.CrossEntropyLoss()
|
|
|
|
+optimizer = nn.Adam(model.trainable_params(), 1e-2)
|
|
|
|
+
|
|
|
|
+# 1. Define forward function
|
|
|
|
+def forward_fn(data, label):
|
|
|
|
+ logits = model(data)
|
|
|
|
+ loss = loss_fn(logits, label)
|
|
|
|
+ return loss, logits
|
|
|
|
+
|
|
|
|
+# 2. Get gradient function
|
|
|
|
+grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
|
|
|
|
+
|
|
|
|
+# 3. Define function of one-step training
|
|
|
|
+def train_step(data, label):
|
|
|
|
+ (loss, _), grads = grad_fn(data, label)
|
|
|
|
+ optimizer(grads)
|
|
|
|
+ return loss
|
|
|
|
+
|
|
|
|
+def train(model, dataset):
|
|
|
|
+ size = dataset.get_dataset_size()
|
|
|
|
+ model.set_train()
|
|
|
|
+ for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
|
|
|
|
+ loss = train_step(data, label)
|
|
|
|
+
|
|
|
|
+ if batch % 100 == 0:
|
|
|
|
+ loss, current = loss.asnumpy(), batch
|
|
|
|
+ print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
|
|
|
|
+
|
|
|
|
+def test(model, dataset, loss_fn):
|
|
|
|
+ num_batches = dataset.get_dataset_size()
|
|
|
|
+ model.set_train(False)
|
|
|
|
+ total, test_loss, correct = 0, 0, 0
|
|
|
|
+ for data, label in dataset.create_tuple_iterator():
|
|
|
|
+ pred = model(data)
|
|
|
|
+ total += len(data)
|
|
|
|
+ test_loss += loss_fn(pred, label).asnumpy()
|
|
|
|
+ correct += (pred.argmax(1) == label).asnumpy().sum()
|
|
|
|
+ test_loss /= num_batches
|
|
|
|
+ correct /= total
|
|
|
|
+ print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ epochs = 10
|
|
|
|
+ mindspore.set_context(device_target="GPU")
|
|
|
|
+ for t in range(epochs):
|
|
|
|
+ print(f"Epoch {t + 1}\n-------------------------------")
|
|
|
|
+ train(model, train_dataset)
|
|
|
|
+ test(model, test_dataset, loss_fn)
|
|
|
|
+ print("Done!")
|