|
@@ -0,0 +1,58 @@
|
|
|
+import mindspore
|
|
|
+import mindspore.nn as nn
|
|
|
+from mindspore.dataset import vision, Cifar10Dataset
|
|
|
+from mindspore.dataset.vision import transforms
|
|
|
+import tqdm
|
|
|
+
|
|
|
+from tests.model.AlexNet import AlexNet
|
|
|
+from tests.secret_func import verify_secret
|
|
|
+from watermark_codec import ModelDecoder
|
|
|
+
|
|
|
+mindspore.set_context(device_target="CPU", max_device_memory="4GB")
|
|
|
+train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
|
|
|
+test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
|
|
|
+batch_size = 32
|
|
|
+key_path = './run/train/key.ckpt'
|
|
|
+save_path = './run/train/AlexNet.ckpt'
|
|
|
+
|
|
|
+
|
|
|
+def datapipe(dataset, batch_size):
|
|
|
+ image_transforms = [
|
|
|
+ vision.Rescale(1.0 / 255.0, 0),
|
|
|
+ vision.Normalize(mean=(0.1307,), std=(0.3081,)),
|
|
|
+ vision.HWC2CHW()
|
|
|
+ ]
|
|
|
+ label_transform = transforms.TypeCast(mindspore.int32)
|
|
|
+
|
|
|
+ dataset = dataset.map(image_transforms, 'image')
|
|
|
+ dataset = dataset.map(label_transform, 'label')
|
|
|
+ dataset = dataset.batch(batch_size)
|
|
|
+ return dataset
|
|
|
+
|
|
|
+
|
|
|
+# Map vision transforms and batch dataset
|
|
|
+train_dataset = datapipe(train_dataset, batch_size)
|
|
|
+test_dataset = datapipe(test_dataset, batch_size)
|
|
|
+
|
|
|
+# Define model
|
|
|
+model = AlexNet(input_channels=3, output_num=10, input_size=32)
|
|
|
+print(model)
|
|
|
+
|
|
|
+# load model from checkpoint
|
|
|
+param_dict = mindspore.load_checkpoint(save_path)
|
|
|
+param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
|
|
|
+
|
|
|
+# init white_box watermark decoder
|
|
|
+layers = []
|
|
|
+for name, layer in model.cells_and_names():
|
|
|
+ if isinstance(layer, nn.Conv2d):
|
|
|
+ layers.append(layer)
|
|
|
+
|
|
|
+# init model decoder
|
|
|
+model_decoder = ModelDecoder(layers=layers[0:2], key_path=key_path)
|
|
|
+
|
|
|
+secret = model_decoder.decode()
|
|
|
+print(f"secret: {secret}")
|
|
|
+
|
|
|
+result = verify_secret(secret)
|
|
|
+print(f"result: {result}")
|