|
@@ -1,7 +1,11 @@
|
|
|
import os
|
|
|
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+
|
|
|
from watermark_verify import logger
|
|
|
from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
|
|
|
+import onnxruntime as ort
|
|
|
|
|
|
|
|
|
def label_verification(model_filename: str) -> bool:
|
|
@@ -28,12 +32,28 @@ def label_verification(model_filename: str) -> bool:
|
|
|
if not public_key or public_key == '':
|
|
|
logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
|
|
|
raise RuntimeError("获取的签名公钥信息为空")
|
|
|
+ qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
|
|
|
+ if not os.path.exists(qrcode_positions_file):
|
|
|
+ raise FileNotFoundError("二维码标签文件不存在")
|
|
|
|
|
|
- # step 2 获取权重文件,使用触发集进行模型推理
|
|
|
+ # step 2 获取权重文件,使用触发集进行模型推理, 判断每种触发集推理结果是否为预期图片分类,如果均比对成功则进行下一步,否则返回False
|
|
|
+ watermark_detect_result = False
|
|
|
+ cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
|
|
|
+ accessed_cls = set()
|
|
|
+ for cls, images in cls_image_mapping.items():
|
|
|
+ for image in images:
|
|
|
+ image_path = os.path.join(trigger_dir, image)
|
|
|
+ detect_result = predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
|
|
|
+ if detect_result:
|
|
|
+ accessed_cls.add(cls)
|
|
|
+ break
|
|
|
+ if accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
|
|
|
+ watermark_detect_result = True
|
|
|
|
|
|
- # step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
|
|
|
+ if not watermark_detect_result: # 如果没有从模型中检测出黑盒水印,直接返回验证失败
|
|
|
+ return False
|
|
|
|
|
|
- # step 4 从触发集图片中提取密码标签,进行验签
|
|
|
+ # step 3 从触发集图片中提取密码标签,进行验签
|
|
|
secret_label = extract_crypto_label_from_trigger(trigger_dir)
|
|
|
label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
|
|
|
return label_check_result
|
|
@@ -73,3 +93,43 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
|
|
|
label = label + label_part
|
|
|
break
|
|
|
return label
|
|
|
+
|
|
|
+
|
|
|
+def preproc(img, input_size, swap=(2, 0, 1)):
|
|
|
+ if len(img.shape) == 3:
|
|
|
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
|
|
+ else:
|
|
|
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
|
|
+
|
|
|
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
|
|
+ resized_img = cv2.resize(
|
|
|
+ img,
|
|
|
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
|
|
|
+ interpolation=cv2.INTER_LINEAR,
|
|
|
+ ).astype(np.uint8)
|
|
|
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
|
|
+
|
|
|
+ padded_img = padded_img.transpose(swap)
|
|
|
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
|
|
+ return padded_img, r
|
|
|
+
|
|
|
+
|
|
|
+def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
|
|
|
+ # 加载ONNX模型
|
|
|
+ session = ort.InferenceSession(model_filename)
|
|
|
+
|
|
|
+ # 加载图像并进行预处理
|
|
|
+ origin_img = cv2.imread(image_path)
|
|
|
+ img, ratio = preproc(origin_img, input_shape)
|
|
|
+
|
|
|
+ # 解析标签文件
|
|
|
+ x_center, y_center, w, h, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)
|
|
|
+
|
|
|
+ # 执行推理
|
|
|
+ input_name = session.get_inputs()[0].name
|
|
|
+ output_name = session.get_outputs()[0].name
|
|
|
+ result = session.run([output_name], {input_name: img[None, :, :, :]})[0]
|
|
|
+
|
|
|
+ # 处理输出结果
|
|
|
+ predicted_class = np.argmax(result, axis=1)[0]
|
|
|
+ return cls == predicted_class
|