Quellcode durchsuchen

修改水印提取过程,修改训练代码,新增水印验证代码

liyan vor 11 Monaten
Ursprung
Commit
a82749a139
3 geänderte Dateien mit 69 neuen und 5 gelöschten Zeilen
  1. 8 3
      tests/train.py
  2. 58 0
      tests/val.py
  3. 3 2
      watermark_codec/model_decoder.py

+ 8 - 3
tests/train.py

@@ -8,6 +8,7 @@ from tests.model.AlexNet import AlexNet
 from tests.secret_func import get_secret
 from watermark_codec import ModelEncoder
 
+mindspore.set_context(device_target="CPU", max_device_memory="4GB")
 train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
 test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
 batch_size = 32
@@ -88,7 +89,7 @@ def test(model, dataset, loss_fn):
         total += len(data)
         loss_embed = model_encoder.get_embeder_loss().asnumpy()
         loss = loss_fn(pred, label).asnumpy() + loss_embed
-        tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed.asnumpy()})  # 添加显示
+        tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed})  # 添加显示
         tqdm_show.update(1)  # 更新进度条
         test_loss += loss
         correct += (pred.argmax(1) == label).asnumpy().sum()
@@ -97,12 +98,16 @@ def test(model, dataset, loss_fn):
     tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
     tqdm_show.close()
 
+def save(model, save_path):
+    mindspore.save_checkpoint(model, save_path)
+
 
 if __name__ == '__main__':
-    epochs = 10
-    mindspore.set_context(device_target="GPU")
+    epochs = 50
+    save_path = "./run/train/AlexNet.ckpt"
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
         train(model, train_dataset)
         test(model, test_dataset, loss_fn)
+        save(model, save_path)
     print("Done!")

+ 58 - 0
tests/val.py

@@ -0,0 +1,58 @@
+import mindspore
+import mindspore.nn as nn
+from mindspore.dataset import vision, Cifar10Dataset
+from mindspore.dataset.vision import transforms
+import tqdm
+
+from tests.model.AlexNet import AlexNet
+from tests.secret_func import verify_secret
+from watermark_codec import ModelDecoder
+
+mindspore.set_context(device_target="CPU", max_device_memory="4GB")
+train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
+test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
+batch_size = 32
+key_path = './run/train/key.ckpt'
+save_path = './run/train/AlexNet.ckpt'
+
+
+def datapipe(dataset, batch_size):
+    image_transforms = [
+        vision.Rescale(1.0 / 255.0, 0),
+        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
+        vision.HWC2CHW()
+    ]
+    label_transform = transforms.TypeCast(mindspore.int32)
+
+    dataset = dataset.map(image_transforms, 'image')
+    dataset = dataset.map(label_transform, 'label')
+    dataset = dataset.batch(batch_size)
+    return dataset
+
+
+# Map vision transforms and batch dataset
+train_dataset = datapipe(train_dataset, batch_size)
+test_dataset = datapipe(test_dataset, batch_size)
+
+# Define model
+model = AlexNet(input_channels=3, output_num=10, input_size=32)
+print(model)
+
+# load model from checkpoint
+param_dict = mindspore.load_checkpoint(save_path)
+param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
+
+# init white_box watermark decoder
+layers = []
+for name, layer in model.cells_and_names():
+    if isinstance(layer, nn.Conv2d):
+        layers.append(layer)
+
+# init model decoder
+model_decoder = ModelDecoder(layers=layers[0:2], key_path=key_path)
+
+secret = model_decoder.decode()
+print(f"secret: {secret}")
+
+result = verify_secret(secret)
+print(f"result: {result}")

+ 3 - 2
watermark_codec/model_decoder.py

@@ -7,7 +7,7 @@ Created on 2024/5/8
 """
 from typing import List
 
-import mindspore as ms
+import numpy as np
 from mindspore import nn
 
 from watermark_codec.tool.str_convertor import bin2string
@@ -27,7 +27,8 @@ class ModelDecoder:
 
     def decode(self):
         prob = get_prob(self.x_random, self.w)
-        decode = ms.ops.where(prob > 0.5, 1, 0)
+        prob = prob.asnumpy()
+        decode = np.where(prob > 0.5, 1, 0)
         code_string = ''.join([str(x) for x in decode.tolist()])
         code_string = bin2string(code_string)
         return code_string