Преглед изворни кода

添加白盒水印验证脚本

liyan пре 1 година
родитељ
комит
cc61fce63b
2 измењених фајлова са 100 додато и 0 уклоњено
  1. 1 0
      README.md
  2. 99 0
      predict_watermark.py

+ 1 - 0
README.md

@@ -89,6 +89,7 @@
 │   └── yolov7_cls.py
 ├── predict_onnx.py #onnx格式模型文件推理
 ├── predict_pt.py #pt格式模型文件推理
+├── predict_watermark.py #加载pt格式模型文件,并且验证白盒水印推理
 ├── predict_trt.py
 ├── prune_last.pt
 ├── requirement

+ 99 - 0
predict_watermark.py

@@ -0,0 +1,99 @@
+"""
+验证白盒水印提取效果
+"""
+import os
+import cv2
+import time
+import torch
+import argparse
+import albumentations
+from tool import secret_func
+from tool.training_embedding import Embedding
+from model.layer import deploy
+
+# -------------------------------------------------------------------------------------------------------------------- #
+parser = argparse.ArgumentParser(description='|pt模型白盒水印提取|')
+parser.add_argument('--model_path', default='best.pt', type=str, help='|pt模型位置|')
+parser.add_argument('--key_path', default='./checkpoints/Alexnet/white_box_embed/x_random.pt', type=str,
+                    help='|白盒模型投影矩阵位置|')
+parser.add_argument('--data_path',
+                    default='/home/yhsun/classification-main/dataset/CIFAR-10/train_cifar10_JPG/airplane', type=str,
+                    help='|图片文件夹位置|')
+parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
+parser.add_argument('--normalization', default='sigmoid', type=str,
+                    help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
+parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
+parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
+parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
+parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
+args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
+# -------------------------------------------------------------------------------------------------------------------- #
+assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
+assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
+if args.float16:
+    assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
+
+
+# -------------------------------------------------------------------------------------------------------------------- #
+def predict_pt(args):
+    # 加载模型
+    model_dict = torch.load(args.model_path, map_location='cpu')
+    model = model_dict['model']
+
+    # 初始化白盒水印编码器
+    # key_path = './checkpoints/Alexnet/white_box_embed/x_random.pt'  # 保存投影矩阵位置
+    embeder = Embedding(model=model.to(args.device), code='', key_path=args.key_path, train=False, device=args.device)
+    code = embeder.test()
+    print(f'code:{code}')
+    if secret_func.verify_secret(code):
+        print('模型水印验证成功')
+    else:
+        print('模型水印验证失败')
+
+    # 检测模型预测指标
+    model = deploy(model, args.normalization)
+    model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
+    epoch = model_dict['epoch_finished']
+    m_ap = round(model_dict['standard'], 4)
+    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | m_ap:{m_ap}|')
+    # 推理
+    image_dir = sorted(os.listdir(args.data_path))
+    start_time = time.time()
+    with torch.no_grad():
+        dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch,
+                                                 shuffle=False, drop_last=False, pin_memory=False,
+                                                 num_workers=args.num_worker)
+        result = []
+        for item, batch in enumerate(dataloader):
+            batch = batch.to(args.device)
+            pred_batch = model(batch).detach().cpu()
+            result.extend(pred_batch.tolist())
+        for i in range(len(result)):
+            result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
+            print(f'| {image_dir[i]}:{result[i]} |')
+    end_time = time.time()
+    print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_dir), args.batch,
+                                                       (end_time - start_time) / len(image_dir)))
+
+
+class torch_dataset(torch.utils.data.Dataset):
+    def __init__(self, image_dir):
+        self.image_dir = image_dir
+        self.transform = albumentations.Compose([
+            albumentations.LongestMaxSize(args.input_size),
+            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
+                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+
+    def __len__(self):
+        return len(self.image_dir)
+
+    def __getitem__(self, index):
+        image = cv2.imread(args.data_path + '/' + self.image_dir[index])  # 读取图片
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
+        image = self.transform(image=image)['image']  # 缩放和填充图片(归一化、调维度在模型中完成)
+        image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
+        return image
+
+
+if __name__ == '__main__':
+    predict_pt(args)