train.py

import mindspore
import mindspore.nn as nn
from mindspore.dataset import vision, Cifar10Dataset
from mindspore.dataset import transforms
import tqdm

from tests.model.AlexNet import AlexNet
from tests.secret_func import get_secret
from watermark_codec import ModelEncoder

# Load the CIFAR-10 binary dataset from disk
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
batch_size = 32


def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset


# Map vision transforms and batch the datasets
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)

# Define the model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# Init the white-box watermark encoder: embed the secret into the first two Conv2d layers
layers = []
secret = get_secret(512)
key_path = './run/train/key.ckpt'
for name, layer in model.cells_and_names():
    if isinstance(layer, nn.Conv2d):
        layers.append(layer)
model_encoder = ModelEncoder(layers=layers[0:2], secret=secret, key_path=key_path)

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), 1e-2)


# 1. Define forward function: add the watermark embedding loss to the task loss
def forward_fn(data, label):
    logits = model(data)
    loss_embed = model_encoder.get_embeder_loss()
    loss = loss_fn(logits, label) + loss_embed
    return loss, loss_embed, logits


# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)


# 3. Define function of one-step training
def train_step(data, label):
    (loss, loss_embed, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss, loss_embed


def train(model, dataset):
    num_batches = dataset.get_dataset_size()
    tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss, loss_embed = train_step(data, label)
        tqdm_show.set_postfix({'train_loss': loss.asnumpy(), 'wm_loss': loss_embed.asnumpy()})  # show losses on the progress bar
        tqdm_show.update(1)  # advance the progress bar
    tqdm_show.close()


def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    tqdm_show = tqdm.tqdm(total=num_batches, desc='test', unit='batch')
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        loss_embed = model_encoder.get_embeder_loss().asnumpy()
        loss = loss_fn(pred, label).asnumpy() + loss_embed
        tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed})  # show losses on the progress bar
        tqdm_show.update(1)  # advance the progress bar
        test_loss += loss
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    tqdm_show.close()


if __name__ == '__main__':
    epochs = 10
    mindspore.set_context(device_target="GPU")
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
    print("Done!")