@@ -2,6 +2,7 @@ import mindspore
 import mindspore.nn as nn
 from mindspore.dataset import vision, Cifar10Dataset
 from mindspore.dataset.vision import transforms
+import tqdm
 
 from tests.model.AlexNet import AlexNet
 
@@ -22,11 +23,11 @@ def datapipe(dataset, batch_size):
     dataset = dataset.batch(batch_size)
     return dataset
 
+
 # Map vision transforms and batch dataset
 train_dataset = datapipe(train_dataset, batch_size)
 test_dataset = datapipe(test_dataset, batch_size)
 
-
 # Define model
 model = AlexNet(input_channels=3, output_num=10, input_size=32)
 print(model)
@@ -35,43 +36,53 @@ print(model)
 loss_fn = nn.CrossEntropyLoss()
 optimizer = nn.Adam(model.trainable_params(), 1e-2)
 
+
 # 1. Define forward function
 def forward_fn(data, label):
     logits = model(data)
     loss = loss_fn(logits, label)
     return loss, logits
 
+
 # 2. Get gradient function
 grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
 
+
 # 3. Define function of one-step training
 def train_step(data, label):
     (loss, _), grads = grad_fn(data, label)
     optimizer(grads)
     return loss
 
+
 def train(model, dataset):
-    size = dataset.get_dataset_size()
+    num_batches = dataset.get_dataset_size()
+    tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')  # progress bar over this epoch's batches
     model.set_train()
     for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
         loss = train_step(data, label)
+        tqdm_show.set_postfix({'train_loss': loss.asnumpy()})  # show the latest loss next to the bar
+        tqdm_show.update(1)  # advance the bar by one batch
+    tqdm_show.close()
 
-        if batch % 100 == 0:
-            loss, current = loss.asnumpy(), batch
-            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
 
 def test(model, dataset, loss_fn):
     num_batches = dataset.get_dataset_size()
     model.set_train(False)
     total, test_loss, correct = 0, 0, 0
+    tqdm_show = tqdm.tqdm(total=num_batches, desc='test', unit='batch')  # progress bar over the test batches
     for data, label in dataset.create_tuple_iterator():
         pred = model(data)
         total += len(data)
-        test_loss += loss_fn(pred, label).asnumpy()
+        loss = loss_fn(pred, label).asnumpy()
+        tqdm_show.set_postfix({'val_loss': loss})  # show the latest loss next to the bar
+        tqdm_show.update(1)  # advance the bar by one batch
+        test_loss += loss
         correct += (pred.argmax(1) == label).asnumpy().sum()
     test_loss /= num_batches
     correct /= total
-    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
+    tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
+    tqdm_show.close()
 
 
 if __name__ == '__main__':
@@ -81,4 +92,4 @@ if __name__ == '__main__':
         print(f"Epoch {t + 1}\n-------------------------------")
         train(model, train_dataset)
         test(model, test_dataset, loss_fn)
-    print("Done!")
+    print("Done!")
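
Note: the progress-bar pattern this patch adds is plain tqdm and is independent of MindSpore. A minimal self-contained sketch of the same calls, where the batch count and loss values are placeholders for illustration, not values from this patch:

    import tqdm

    num_batches = 100  # placeholder, stands in for dataset.get_dataset_size()
    bar = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
    for step in range(num_batches):
        loss = 1.0 / (step + 1)  # placeholder, stands in for train_step(data, label)
        bar.set_postfix({'train_loss': loss})  # metrics shown to the right of the bar
        bar.update(1)  # advance the bar by one completed batch
    bar.close()  # finish the bar and release the terminal line

An equivalent alternative is to wrap the iterator itself, e.g. tqdm.tqdm(dataset.create_tuple_iterator(), total=num_batches, desc='train'), which updates and closes the bar automatically; the explicit update()/close() form used in the patch leaves the loop header unchanged and keeps a bar object at hand for set_postfix() and write().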
|