# train.py — AlexNet CIFAR-10 training with white-box watermark embedding.
  1. import mindspore
  2. import mindspore.nn as nn
  3. from mindspore.dataset import vision, Cifar10Dataset
  4. from mindspore.dataset.vision import transforms
  5. import tqdm
  6. from tests.model.AlexNet import AlexNet
  7. from tests.secret_func import get_secret
  8. from watermark_codec import ModelEncoder
# Run on CPU and cap device memory at 4 GB.
mindspore.set_context(device_target="CPU", max_device_memory="4GB")
# CIFAR-10 binary files are expected under data/cifar-10-batches-bin;
# only the training split is shuffled.
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
batch_size = 32
  13. def datapipe(dataset, batch_size):
  14. image_transforms = [
  15. vision.Rescale(1.0 / 255.0, 0),
  16. vision.Normalize(mean=(0.1307,), std=(0.3081,)),
  17. vision.HWC2CHW()
  18. ]
  19. label_transform = transforms.TypeCast(mindspore.int32)
  20. dataset = dataset.map(image_transforms, 'image')
  21. dataset = dataset.map(label_transform, 'label')
  22. dataset = dataset.batch(batch_size)
  23. return dataset
  24. # Map vision transforms and batch dataset
  25. train_dataset = datapipe(train_dataset, batch_size)
  26. test_dataset = datapipe(test_dataset, batch_size)
  27. # Define model
  28. model = AlexNet(input_channels=3, output_num=10, input_size=32)
  29. print(model)
  30. # init white_box watermark encoder
  31. layers = []
  32. secret = get_secret(512)
  33. key_path = './run/train/key.ckpt'
  34. for name, layer in model.cells_and_names():
  35. if isinstance (layer, nn.Conv2d):
  36. layers.append(layer)
  37. model_encoder = ModelEncoder(layers=layers[0:2], secret=secret, key_path=key_path)
  38. # Instantiate loss function and optimizer
  39. loss_fn = nn.CrossEntropyLoss()
  40. optimizer = nn.Adam(model.trainable_params(), 1e-2)
  41. # 1. Define forward function
  42. def forward_fn(data, label):
  43. logits = model(data)
  44. loss_embed = model_encoder.get_embeder_loss()
  45. loss = loss_fn(logits, label) + loss_embed
  46. return loss, loss_embed, logits
# 2. Get gradient function
# Differentiates forward_fn w.r.t. the optimizer's parameters; has_aux=True
# means only the first return value (total loss) is differentiated and the
# rest (watermark loss, logits) are passed through as auxiliary outputs.
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
  49. # 3. Define function of one-step training
  50. def train_step(data, label):
  51. (loss, loss_embed, _), grads = grad_fn(data, label)
  52. optimizer(grads)
  53. return loss, loss_embed
  54. def train(model, dataset):
  55. num_batches = dataset.get_dataset_size()
  56. tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
  57. model.set_train()
  58. for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
  59. loss, loss_embed = train_step(data, label)
  60. tqdm_show.set_postfix({'train_loss': loss.asnumpy(), 'wm_loss': loss_embed.asnumpy()}) # 添加显示
  61. tqdm_show.update(1) # 更新进度条
  62. tqdm_show.close()
  63. def test(model, dataset, loss_fn):
  64. num_batches = dataset.get_dataset_size()
  65. model.set_train(False)
  66. total, test_loss, correct = 0, 0, 0
  67. tqdm_show = tqdm.tqdm(total=num_batches, desc='test', unit='batch')
  68. for data, label in dataset.create_tuple_iterator():
  69. pred = model(data)
  70. total += len(data)
  71. loss_embed = model_encoder.get_embeder_loss().asnumpy()
  72. loss = loss_fn(pred, label).asnumpy() + loss_embed
  73. tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed}) # 添加显示
  74. tqdm_show.update(1) # 更新进度条
  75. test_loss += loss
  76. correct += (pred.argmax(1) == label).asnumpy().sum()
  77. test_loss /= num_batches
  78. correct /= total
  79. tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
  80. tqdm_show.close()
def save(model, save_path):
    """Persist the model's parameters to a MindSpore checkpoint at `save_path`."""
    mindspore.save_checkpoint(model, save_path)
  83. def export_onnx(model):
  84. mindspore.export(model, train_dataset, file_name='./run/train/AlexNet.onnx', file_format='ONNX')
  85. if __name__ == '__main__':
  86. epochs = 50
  87. save_path = "./run/train/AlexNet.ckpt"
  88. for t in range(epochs):
  89. print(f"Epoch {t + 1}\n-------------------------------")
  90. train(model, train_dataset)
  91. test(model, test_dataset, loss_fn)
  92. save(model, save_path)
  93. export_onnx(model)
  94. print("Done!")