"""Train and evaluate AlexNet on CIFAR-10 with MindSpore."""
import mindspore
import mindspore.nn as nn
from mindspore.dataset import Cifar10Dataset, transforms, vision
import tqdm

from tests.model.AlexNet import AlexNet

# Directory containing the CIFAR-10 binary dataset.
DATA_DIR = 'data/cifar-10-batches-bin'

# Commonly used per-channel (R, G, B) statistics of the CIFAR-10 training set,
# expressed on the [0, 1] scale (i.e. after Rescale(1/255)).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

train_dataset = Cifar10Dataset(dataset_dir=DATA_DIR, usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir=DATA_DIR, usage='test')
batch_size = 32


def datapipe(dataset, batch_size, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Apply image/label preprocessing to ``dataset`` and batch it.

    Images are rescaled from [0, 255] to [0, 1], normalised per channel with
    ``mean``/``std`` and converted from HWC to CHW layout; labels are cast to
    int32.

    Args:
        dataset: A MindSpore dataset with ``'image'`` and ``'label'`` columns.
        batch_size: Number of samples per batch.
        mean: Per-channel mean on the [0, 1] scale (defaults to CIFAR-10 RGB).
        std: Per-channel standard deviation on the [0, 1] scale
            (defaults to CIFAR-10 RGB).

    Returns:
        The mapped and batched dataset.
    """
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        # Applied while the image is still HWC (before HWC2CHW); mean/std should
        # provide one value per image channel (3 for RGB CIFAR-10 images).
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW(),
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    dataset = dataset.map(operations=image_transforms, input_columns='image')
    dataset = dataset.map(operations=label_transform, input_columns='label')
    return dataset.batch(batch_size)


# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)

# Define model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): 1e-2 is a high learning rate for Adam; if the loss diverges or
# stalls, try 1e-3.
learning_rate = 1e-2
optimizer = nn.Adam(model.trainable_params(), learning_rate)


# 1. Define forward function
def forward_fn(data, label):
    """Run the model on one batch and return ``(loss, logits)``."""
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits


# 2. Get gradient function
# has_aux=True: only the first output of forward_fn (the loss) is differentiated;
# the logits are passed through as an auxiliary output.
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# 3. Define function of one-step training
def train_step(data, label):
    """Run one optimisation step on a single batch and return its loss.

    Uses the module-level ``grad_fn`` and ``optimizer``.
    """
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss


def train(model, dataset):
    """Train for one epoch over ``dataset``, showing a tqdm progress bar.

    Args:
        model: The network; it is switched to training mode here.
        dataset: Batched dataset yielding ``(data, label)`` tuples.

    NOTE(review): parameter updates go through the module-level ``train_step``
    (``grad_fn``/``optimizer``), not through ``model`` directly; ``model`` is
    only used for ``set_train``.
    """
    num_batches = dataset.get_dataset_size()
    model.set_train()
    # The context manager guarantees the bar is closed even if training raises.
    with tqdm.tqdm(total=num_batches, desc='train', unit='batch') as progress:
        for data, label in dataset.create_tuple_iterator():
            loss = train_step(data, label)
            # Show the current batch loss as a plain float.
            progress.set_postfix({'train_loss': float(loss.asnumpy())})
            progress.update(1)


def test(model, dataset, loss_fn):
    """Evaluate ``model`` on ``dataset`` and print accuracy and mean batch loss.

    Args:
        model: The network; it is switched to evaluation mode here.
        dataset: Batched dataset yielding ``(data, label)`` tuples.
        loss_fn: Callable ``loss_fn(logits, label)`` returning a scalar tensor.
    """
    model.set_train(False)
    num_batches = dataset.get_dataset_size()
    total, correct = 0, 0
    loss_sum, seen_batches = 0.0, 0
    with tqdm.tqdm(total=num_batches, desc='test', unit='batch') as progress:
        for data, label in dataset.create_tuple_iterator():
            pred = model(data)
            loss = float(loss_fn(pred, label).asnumpy())
            loss_sum += loss
            seen_batches += 1
            total += len(data)
            correct += int((pred.argmax(1) == label).asnumpy().sum())
            # Show the current batch loss and advance the progress bar.
            progress.set_postfix({'val_loss': loss})
            progress.update(1)

        # Guard against an empty dataset, which would otherwise divide by zero.
        if seen_batches == 0 or total == 0:
            progress.write("Test: no batches to evaluate\n")
            return

        accuracy = correct / total
        # Average over the batches actually iterated, not the reported size.
        avg_loss = loss_sum / seen_batches
        progress.write(f"Test: \n Accuracy: {(100 * accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")


if __name__ == '__main__':
    epochs = 10
    # NOTE(review): the datasets and model are created at import time, before
    # this call; MindSpore generally expects the context to be configured before
    # building networks — consider moving this to the top of the module.
    mindspore.set_context(device_target="GPU")
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
    print("Done!")