Переглянути джерело

添加测试代码,修改验证工具异常

liyan 11 місяців тому
батько
коміт
21abab3f40
3 змінених файлів з 20 додано та 11 видалено
  1. 10 0
      tests/sign_verify_test.py
  2. 6 0
      tests/verify_tool_test.py
  3. 4 11
      watermark_verify/verify_tool.py

+ 10 - 0
tests/sign_verify_test.py

@@ -0,0 +1,10 @@
+import time
+
+from watermark_verify.tools import secret_label_func
+from watermark_verify import logger
+
+if __name__ == '__main__':
+    ts = int(time.time())
+    secret_label, public_key = secret_label_func.generate_secret_label(str(ts))
+    verify_result = secret_label_func.verify_secret_label(secret_label, public_key)
+    logger.debug(f"verify_result: {verify_result}")

+ 6 - 0
tests/verify_tool_test.py

@@ -0,0 +1,6 @@
+from watermark_verify import verify_tool
+
+if __name__ == '__main__':
+    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/test.onnx"
+    verify_result = verify_tool.label_verification(model_filename)
+    print(f"verify_result: {verify_result}")

+ 4 - 11
watermark_verify/verify_tool.py

@@ -1,5 +1,4 @@
 import os
-import time
 
 from watermark_verify import logger
 from watermark_verify.tools import secret_label_func, qrcode_tool
@@ -16,7 +15,7 @@ def label_verification(model_filename: str) -> bool:
     logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
     # step 1 获取触发集目录,公钥信息
     trigger_dir = os.path.join(root_dir, 'trigger')
-    public_key_txt = os.path.join(root_dir, 'public.key')
+    public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
     if not os.path.exists(trigger_dir):
         logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
         raise FileExistsError("触发集目录不存在")
@@ -26,7 +25,7 @@ def label_verification(model_filename: str) -> bool:
     with open(public_key_txt, 'r') as file:
         public_key = file.read()
     logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
-    if public_key and public_key != '':
+    if not public_key or public_key == '':
         logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
         raise RuntimeError("获取的签名公钥信息为空")
 
@@ -36,7 +35,7 @@ def label_verification(model_filename: str) -> bool:
 
     # step 4 从触发集图片中提取密码标签,进行验签
     secret_label = extract_crypto_label_from_trigger(trigger_dir)
-    label_check_result = secret_label_func.verify_secret_label(secret_label= secret_label, public_key=public_key)
+    label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
     return label_check_result
 
 
@@ -68,14 +67,8 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
         images = os.listdir(sub_pic_dir)
         for image in images:
             img_path = os.path.join(sub_pic_dir, image)
-            label_part = qrcode_tool.detect_and_decode_qr_code(img_path)
+            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path)
             if label_part is not None:
                 label = label + label_part
                 break
     return label
-
-if __name__ == '__main__':
-    ts = int(time.time())
-    secret_label, public_key = secret_label_func.generate_secret_label(str(ts))
-    verify_result = secret_label_func.verify_secret_label(secret_label, public_key)
-    logger.debug(f"verify_result: {verify_result}")