val.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import mindspore
  2. import mindspore.nn as nn
  3. from mindspore.dataset import vision, Cifar10Dataset
  4. from mindspore.dataset.vision import transforms
  5. import tqdm
  6. from tests.model.AlexNet import AlexNet
  7. from tests.secret_func import verify_secret
  8. from watermark_codec import ModelDecoder
  9. mindspore.set_context(device_target="CPU", max_device_memory="4GB")
  10. train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
  11. test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
  12. batch_size = 32
  13. key_path = './run/train/key.ckpt'
  14. save_path = './run/train/AlexNet.ckpt'
  15. def datapipe(dataset, batch_size):
  16. image_transforms = [
  17. vision.Rescale(1.0 / 255.0, 0),
  18. vision.Normalize(mean=(0.1307,), std=(0.3081,)),
  19. vision.HWC2CHW()
  20. ]
  21. label_transform = transforms.TypeCast(mindspore.int32)
  22. dataset = dataset.map(image_transforms, 'image')
  23. dataset = dataset.map(label_transform, 'label')
  24. dataset = dataset.batch(batch_size)
  25. return dataset
  26. # Map vision transforms and batch dataset
  27. train_dataset = datapipe(train_dataset, batch_size)
  28. test_dataset = datapipe(test_dataset, batch_size)
  29. # Define model
  30. model = AlexNet(input_channels=3, output_num=10, input_size=32)
  31. print(model)
  32. # load model from checkpoint
  33. param_dict = mindspore.load_checkpoint(save_path)
  34. param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
  35. # init white_box watermark decoder
  36. layers = []
  37. for name, layer in model.cells_and_names():
  38. if isinstance(layer, nn.Conv2d):
  39. layers.append(layer)
  40. # init model decoder
  41. model_decoder = ModelDecoder(layers=layers[0:2], key_path=key_path)
  42. secret = model_decoder.decode()
  43. print(f"secret: {secret}")
  44. result = verify_secret(secret)
  45. print(f"result: {result}")