Explorar o código

增加图像分类黑盒水印验证流程

liyan hai 11 meses
pai
achega
11151ccc74
Modificáronse 1 ficheiros con 63 adicións e 3 borrados
  1. 63 3
      watermark_verify/verify_tool.py

+ 63 - 3
watermark_verify/verify_tool.py

@@ -1,7 +1,11 @@
 import os
 
+import cv2
+import numpy as np
+
 from watermark_verify import logger
 from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
+import onnxruntime as ort
 
 
 def label_verification(model_filename: str) -> bool:
@@ -28,12 +32,28 @@ def label_verification(model_filename: str) -> bool:
     if not public_key or public_key == '':
         logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
         raise RuntimeError("获取的签名公钥信息为空")
+    qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
+    if not os.path.exists(qrcode_positions_file):
+        raise FileNotFoundError("二维码标签文件不存在")
 
-    # step 2 获取权重文件,使用触发集进行模型推理
+    # step 2 获取权重文件,使用触发集进行模型推理, 判断每种触发集推理结果是否为预期图片分类,如果均比对成功则进行下一步,否则返回False
+    watermark_detect_result = False
+    cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
+    accessed_cls = set()
+    for cls, images in cls_image_mapping.items():
+        for image in images:
+            image_path = os.path.join(trigger_dir, image)
+            detect_result = predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
+            if detect_result:
+                accessed_cls.add(cls)
+                break
+    if accessed_cls == set(cls_image_mapping.keys()):  # 所有的分类都检测出模型水印,模型水印检测结果为True
+        watermark_detect_result = True
 
-    # step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
+    if not watermark_detect_result:  # 如果没有从模型中检测出黑盒水印,直接返回验证失败
+        return False
 
-    # step 4 从触发集图片中提取密码标签,进行验签
+    # step 3 从触发集图片中提取密码标签,进行验签
     secret_label = extract_crypto_label_from_trigger(trigger_dir)
     label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
     return label_check_result
@@ -73,3 +93,43 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
                 label = label + label_part
                 break
     return label
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
+    # 加载ONNX模型
+    session = ort.InferenceSession(model_filename)
+
+    # 加载图像并进行预处理
+    origin_img = cv2.imread(image_path)
+    img, ratio = preproc(origin_img, input_shape)
+
+    # 解析标签文件
+    x_center, y_center, w, h, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)
+
+    # 执行推理
+    input_name = session.get_inputs()[0].name
+    output_name = session.get_outputs()[0].name
+    result = session.run([output_name], {input_name: img[None, :, :, :]})[0]
+
+    # 处理输出结果
+    predicted_class = np.argmax(result, axis=1)[0]
+    return cls == predicted_class