|
@@ -1,5 +1,4 @@
|
|
|
import os
|
|
|
-import time
|
|
|
|
|
|
from watermark_verify import logger
|
|
|
from watermark_verify.tools import secret_label_func, qrcode_tool
|
|
@@ -16,7 +15,7 @@ def label_verification(model_filename: str) -> bool:
|
|
|
logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
|
# step 1 获取触发集目录,公钥信息
|
|
|
trigger_dir = os.path.join(root_dir, 'trigger')
|
|
|
- public_key_txt = os.path.join(root_dir, 'public.key')
|
|
|
+ public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
|
|
|
if not os.path.exists(trigger_dir):
|
|
|
logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
|
|
|
raise FileExistsError("触发集目录不存在")
|
|
@@ -26,7 +25,7 @@ def label_verification(model_filename: str) -> bool:
|
|
|
with open(public_key_txt, 'r') as file:
|
|
|
public_key = file.read()
|
|
|
logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
|
|
|
- if public_key and public_key != '':
|
|
|
+ if not public_key or public_key == '':
|
|
|
logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
|
|
|
raise RuntimeError("获取的签名公钥信息为空")
|
|
|
|
|
@@ -36,7 +35,7 @@ def label_verification(model_filename: str) -> bool:
|
|
|
|
|
|
# step 4 从触发集图片中提取密码标签,进行验签
|
|
|
secret_label = extract_crypto_label_from_trigger(trigger_dir)
|
|
|
- label_check_result = secret_label_func.verify_secret_label(secret_label= secret_label, public_key=public_key)
|
|
|
+ label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
|
|
|
return label_check_result
|
|
|
|
|
|
|
|
@@ -68,14 +67,8 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
|
|
|
images = os.listdir(sub_pic_dir)
|
|
|
for image in images:
|
|
|
img_path = os.path.join(sub_pic_dir, image)
|
|
|
- label_part = qrcode_tool.detect_and_decode_qr_code(img_path)
|
|
|
+ label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path)
|
|
|
if label_part is not None:
|
|
|
label = label + label_part
|
|
|
break
|
|
|
return label
|
|
|
-
|
|
|
-if __name__ == '__main__':
|
|
|
- ts = int(time.time())
|
|
|
- secret_label, public_key = secret_label_func.generate_secret_label(str(ts))
|
|
|
- verify_result = secret_label_func.verify_secret_label(secret_label, public_key)
|
|
|
- logger.debug(f"verify_result: {verify_result}")
|