Browse Source

新增白盒水印验证流程代码

liyan 1 year ago
parent
commit
0922c7b627
3 changed files with 165 additions and 1 deletions
  1. 36 0
      block/secret_get.py
  2. 126 0
      predict_pt_embed.py
  3. 3 1
      train_embed.py

File diff suppressed because it is too large
+ 36 - 0
block/secret_get.py


+ 126 - 0
predict_pt_embed.py

@@ -0,0 +1,126 @@
+import os
+import time
+
+import numpy as np
+import torch
+import argparse
+from PIL import Image
+from torch import nn
+from torchvision import transforms
+from watermark_codec import ModelDecoder
+
+from block import secret_get
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|pt模型推理|')
+parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
+parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
+parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
+parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
+parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
+parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
+parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
+assert os.path.exists(args.key_path), f'! key_path:{args.key_path} !'
+assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
+if args.float16:
+    assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def predict_pt(args):
+    # 加载模型
+    model_dict = torch.load(args.model_path, map_location='cpu')
+    model = model_dict['model']
+    model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
+    epoch = model_dict['epoch_finished']
+    accuracy = round(model_dict['standard'], 4)
+    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
+
+    # 选择加密层并初始化白盒水印编码器
+    conv_list = []
+    for module in model.modules():
+        if isinstance(module, nn.Conv2d):
+            conv_list.append(module)
+    conv_list = conv_list[0:2]
+    decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device)  # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
+    secret_extract = decoder.decode()  # 提取密码标签
+    result = secret_get.verify_secret(secret_extract)
+    print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
+
+    # 推理
+    start_time = time.time()
+    with torch.no_grad():
+        print(f"加载测试集至内存...")
+        transform = transforms.Compose([
+            transforms.ToTensor(),  # 将图像转换为PyTorch张量
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
+        ])
+        dataset = CustomDataset(data_dir=args.data_path, transform=transform)
+        dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
+                                                 shuffle=False, drop_last=False, pin_memory=False,
+                                                 num_workers=args.num_worker)
+        print(f"加载测试集完成,开始预测...")
+        correct = 0
+        total = 0
+        epoch = 0
+        for index, (image_batch, true_batch) in enumerate(dataloader):
+            image_batch = image_batch.to(args.device)
+            pred_batch = model(image_batch).detach().cpu()
+            # 获取指标项
+            _, predicted = torch.max(pred_batch, 1)
+            total += true_batch.size(0)
+            correct += (predicted == true_batch).sum().item()
+            epoch = epoch + 1
+        # 计算指标
+        accuracy = correct / total
+        end_time = time.time()
+        print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
+
+
+class CustomDataset(torch.utils.data.Dataset):
+    """
+    自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
+    """
+
+    def __init__(self, data_dir, image_size=(32, 32), transform=None):
+        self.data_dir = data_dir
+        self.image_size = image_size
+        self.transform = transform
+
+        self.images = []
+        self.labels = []
+
+        # 遍历指定目录下的子目录,每个子目录代表一个类别
+        class_dirs = sorted(os.listdir(data_dir))
+        for index, class_dir in enumerate(class_dirs):
+            class_path = os.path.join(data_dir, class_dir)
+
+            # 遍历当前类别目录下的图像文件
+            for image_file in os.listdir(class_path):
+                image_path = os.path.join(class_path, image_file)
+
+                # 使用PIL加载图像并调整大小
+                image = Image.open(image_path).convert('RGB')
+                image = image.resize(image_size)
+
+                self.images.append(np.array(image))
+                self.labels.append(index)
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        label = self.labels[idx]
+
+        if self.transform:
+            image = self.transform(Image.fromarray(image))
+
+        return image, label
+
+
+if __name__ == '__main__':
+    predict_pt(args)

File diff suppressed because it is too large
+ 3 - 1
train_embed.py