"""Extract and verify the white-box watermark embedded in a trained AlexNet.

Loads the AlexNet checkpoint at ``save_path``, collects the model's first two
``nn.Conv2d`` layers, decodes the embedded secret with ``ModelDecoder`` using
the key at ``key_path``, and prints the secret together with the result of
``verify_secret``.
"""
import mindspore
import mindspore.nn as nn
from mindspore.dataset import vision, Cifar10Dataset
from mindspore.dataset.vision import transforms
import tqdm

from tests.model.AlexNet import AlexNet
from tests.secret_func import verify_secret
from watermark_codec import ModelDecoder

mindspore.set_context(device_target="CPU", max_device_memory="4GB")

# NOTE(review): the datasets below are built but never iterated by this script;
# watermark decoding reads only the model's conv weights.
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
batch_size = 32
key_path = './run/train/key.ckpt'
save_path = './run/train/AlexNet.ckpt'


def datapipe(dataset, batch_size):
    """Attach image/label preprocessing to *dataset* and batch it.

    Args:
        dataset: A MindSpore dataset with ``'image'`` and ``'label'`` columns.
        batch_size: Number of samples per batch.

    Returns:
        The mapped and batched dataset.
    """
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        # NOTE(review): 0.1307 / 0.3081 are the usual single-channel MNIST
        # statistics, applied here to 3-channel CIFAR-10 images. Kept as-is in
        # case the training pipeline uses the same values — confirm.
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW(),
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset


# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)

# Define model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# Load model from checkpoint
param_dict = mindspore.load_checkpoint(save_path)
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
if param_not_load:
    # Decoding reads the conv weights directly. Any network parameter missing
    # from the checkpoint keeps its initial value, and that would silently
    # corrupt the decoded secret.
    print(f"warning: {len(param_not_load)} parameter(s) not loaded from {save_path}: {param_not_load}")

# Init white-box watermark decoder: gather conv layers in traversal order
layers = [layer for _, layer in model.cells_and_names() if isinstance(layer, nn.Conv2d)]

# Init model decoder on the first two conv layers (presumably the same layers
# the watermark was embedded into — confirm against the encoder script).
model_decoder = ModelDecoder(layers=layers[:2], key_path=key_path)
secret = model_decoder.decode()
print(f"secret: {secret}")
result = verify_secret(secret)
print(f"result: {result}")