12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import mindspore
- import mindspore.nn as nn
- from mindspore.dataset import vision, Cifar10Dataset
- from mindspore.dataset.vision import transforms
- import tqdm
- from tests.model.AlexNet import AlexNet
- from tests.secret_func import verify_secret
- from watermark_codec import ModelDecoder
# Execute on the CPU backend with the device memory limit set to 4GB.
mindspore.set_context(device_target="CPU", max_device_memory="4GB")

# Batch size for the data pipelines and locations of the watermark key and
# the model checkpoint that are loaded further down.
batch_size = 32
key_path = './run/train/key.ckpt'
save_path = './run/train/AlexNet.ckpt'

# CIFAR-10 splits read from the binary distribution; only the training split
# is shuffled. NOTE(review): these datasets are not consumed by the watermark
# decoding in this script.
_cifar_dir = 'data/cifar-10-batches-bin'
train_dataset = Cifar10Dataset(dataset_dir=_cifar_dir, usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir=_cifar_dir, usage='test')
def datapipe(dataset, batch_size, *, mean=(0.1307,), std=(0.3081,)):
    """Apply image/label transforms to a MindSpore dataset and batch it.

    Images are multiplied by 1/255 (shift 0), normalized with ``mean`` and
    ``std``, and converted from HWC to CHW layout. Labels are cast to int32.

    Args:
        dataset: MindSpore dataset with ``'image'`` and ``'label'`` columns.
        batch_size: Number of samples per batch.
        mean: Channel means passed to ``vision.Normalize`` (applied after the
            1/255 rescale).
        std: Channel standard deviations passed to ``vision.Normalize``.

    Returns:
        The mapped and batched dataset.
    """
    # NOTE(review): the defaults are single-channel MNIST statistics, while the
    # model in this script takes 3-channel CIFAR-10 input. They are kept as
    # defaults so existing callers get identical preprocessing; pass 3-element
    # statistics to override. Confirm that Normalize accepts a 1-element
    # mean/std for 3-channel images in the MindSpore version in use.
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW(),
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    return dataset.batch(batch_size)
# Attach the preprocessing transforms and batching to both CIFAR-10 splits
# (training split first, then test split).
train_dataset, test_dataset = (
    datapipe(split, batch_size) for split in (train_dataset, test_dataset)
)
# AlexNet configured for 32x32 inputs with 3 channels and 10 outputs.
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# Restore weights from the checkpoint. load_param_into_net returns the names
# of network parameters that were not found in the checkpoint. Those keep
# their initial values, so any watermark decoded from them would not reflect
# the trained model. Refuse to continue rather than report a misleading result.
param_dict = mindspore.load_checkpoint(save_path)
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
if param_not_load:
    raise RuntimeError(
        f"{len(param_not_load)} network parameter(s) were not loaded from "
        f"checkpoint {save_path!r}: {param_not_load}"
    )
# White-box watermark extraction: collect every Conv2d cell of the network
# in the order cells_and_names() yields them.
layers = [
    cell for _, cell in model.cells_and_names() if isinstance(cell, nn.Conv2d)
]

# Decode the embedded secret from the first two convolutional layers with the
# stored key. Presumably these must be the same layers used when embedding;
# verify against the encoder side.
model_decoder = ModelDecoder(layers=layers[:2], key_path=key_path)
secret = model_decoder.decode()
print(f"secret: {secret}")

# Check the recovered secret and report the outcome.
result = verify_secret(secret)
print(f"result: {result}")
|