浏览代码

添加签名公钥获取,添加从触发集中获取图片,提取密码标签,验签密码标签的流程

liyan 11 月之前
父节点
当前提交
1e9444fe64
共有 3 个文件被更改,包括 85 次插入4 次删除
  1. 2 2
      watermark_verify/tools/general_tool.py
  2. 25 0
      watermark_verify/tools/qrcode_tool.py
  3. 58 2
      watermark_verify/verify_tool.py

+ 2 - 2
watermark_verify/tools/general_tool.py

@@ -24,7 +24,7 @@ def divide_string(s, num_parts):
     return parts
 
 
-def find_yolox_directories(root_dir, target_dir):
+def find_relative_directories(root_dir, target_dir):
     """
     查找指定目录下的目标目录相对路径
     :param root_dir: 根目录
@@ -34,7 +34,7 @@ def find_yolox_directories(root_dir, target_dir):
     root_path = Path(root_dir)
     yolox_paths = []
 
-    # 递归查找名为 'yolox' 的目录
+    # 递归查找指定目录
     for path in root_path.rglob(target_dir):
         if path.is_dir():
             # 计算相对路径

+ 25 - 0
watermark_verify/tools/qrcode_tool.py

@@ -0,0 +1,25 @@
+import cv2
+
+
+def detect_and_decode_qr_code(img_path):
+    """
+    从指定图片检测和提取二维码
+    :param img_path: 指定图片位置
+    :return: (二维码信息,二维码位置)
+    """
+    image = cv2.imread(img_path)
+    # Initialize the QRCode detector
+    qr_code_detector = cv2.QRCodeDetector()
+
+    # Detect and decode the QR code
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(image)
+
+    if points is not None:
+        # Convert to integer type
+        points = points[0].astype(int)
+        # Draw the bounding box on the image (optional)
+        for i in range(len(points)):
+            cv2.line(image, tuple(points[i]), tuple(points[(i + 1) % len(points)]), (255, 0, 0), 2)
+        return decoded_text, points
+    else:
+        return None, None

+ 58 - 2
watermark_verify/verify_tool.py

@@ -1,7 +1,8 @@
+import os
 import time
 
 from watermark_verify import logger
-from watermark_verify.tools import secret_label_func
+from watermark_verify.tools import secret_label_func, qrcode_tool
 
 
 def label_verification(model_filename: str) -> bool:
@@ -10,14 +11,69 @@ def label_verification(model_filename: str) -> bool:
     :param model_filename: 模型权重文件,om格式
     :return: 模型标签验证结果
     """
+    root_dir = os.path.dirname(model_filename)
     label_check_result = False
-    # step 1 获取触发集,公钥信息
+    logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
+    # step 1 获取触发集目录,公钥信息
+    trigger_dir = os.path.join(root_dir, 'trigger')
+    public_key_txt = os.path.join(root_dir, 'public.key')
+    if not os.path.exists(trigger_dir):
+        logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
+        raise FileExistsError("触发集目录不存在")
+    if not os.path.exists(public_key_txt):
+        logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
+        raise FileExistsError("签名公钥文件不存在")
+    with open(public_key_txt, 'r') as file:
+        public_key = file.read()
+    logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
+    if public_key and public_key != '':
+        logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
+        raise RuntimeError("获取的签名公钥信息为空")
+
     # step 2 获取权重文件,使用触发集进行模型推理
+
     # step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
+
     # step 4 从触发集图片中提取密码标签,进行验签
+    secret_label = extract_crypto_label_from_trigger(trigger_dir)
+    label_check_result = secret_label_func.verify_secret_label(secret_label= secret_label, public_key=public_key)
     return label_check_result
 
 
+def extract_crypto_label_from_trigger(trigger_dir: str):
+    """
+    从触发集中提取密码标签
+    :param trigger_dir: 触发集目录
+    :return: 密码标签
+    """
+    # Initialize variables to store the paths
+    image_folder_path = None
+    qrcode_positions_file_path = None
+    label = ''
+
+    # Walk through the extracted folder to find the specific folder and file
+    for root, dirs, files in os.walk(trigger_dir):
+        if 'images' in dirs:
+            image_folder_path = os.path.join(root, 'images')
+        if 'qrcode_positions.txt' in files:
+            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
+    if image_folder_path is None:
+        raise FileNotFoundError("触发集目录不存在images文件夹")
+    if qrcode_positions_file_path is None:
+        raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
+
+    sub_image_dir_names = os.listdir(image_folder_path)
+    for sub_image_dir_name in sub_image_dir_names:
+        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
+        images = os.listdir(sub_pic_dir)
+        for image in images:
+            img_path = os.path.join(sub_pic_dir, image)
+            label_part = qrcode_tool.detect_and_decode_qr_code(img_path)
+            if label_part is not None:
+                label = label + label_part
+                break
+    return label
+
 if __name__ == '__main__':
     ts = int(time.time())
     secret_label, public_key = secret_label_func.generate_secret_label(str(ts))