|
@@ -1,7 +1,8 @@
|
|
|
+import os
|
|
|
import time
|
|
|
|
|
|
from watermark_verify import logger
|
|
|
-from watermark_verify.tools import secret_label_func
|
|
|
+from watermark_verify.tools import secret_label_func, qrcode_tool
|
|
|
|
|
|
|
|
|
def label_verification(model_filename: str) -> bool:
|
|
@@ -10,14 +11,69 @@ def label_verification(model_filename: str) -> bool:
|
|
|
:param model_filename: 模型权重文件,om格式
|
|
|
:return: 模型标签验证结果
|
|
|
"""
|
|
|
+ root_dir = os.path.dirname(model_filename)
|
|
|
label_check_result = False
|
|
|
- # step 1 获取触发集,公钥信息
|
|
|
+ logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
|
+ # step 1 获取触发集目录,公钥信息
|
|
|
+ trigger_dir = os.path.join(root_dir, 'trigger')
|
|
|
+ public_key_txt = os.path.join(root_dir, 'public.key')
|
|
|
+ if not os.path.exists(trigger_dir):
|
|
|
+ logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
|
|
|
+ raise FileExistsError("触发集目录不存在")
|
|
|
+ if not os.path.exists(public_key_txt):
|
|
|
+ logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
|
|
|
+ raise FileExistsError("签名公钥文件不存在")
|
|
|
+ with open(public_key_txt, 'r') as file:
|
|
|
+ public_key = file.read()
|
|
|
+ logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
|
|
|
+ if public_key and public_key != '':
|
|
|
+ logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
|
|
|
+ raise RuntimeError("获取的签名公钥信息为空")
|
|
|
+
|
|
|
# step 2 获取权重文件,使用触发集进行模型推理
|
|
|
+
|
|
|
# step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
|
|
|
+
|
|
|
# step 4 从触发集图片中提取密码标签,进行验签
|
|
|
+ secret_label = extract_crypto_label_from_trigger(trigger_dir)
|
|
|
+ label_check_result = secret_label_func.verify_secret_label(secret_label= secret_label, public_key=public_key)
|
|
|
return label_check_result
|
|
|
|
|
|
|
|
|
+def extract_crypto_label_from_trigger(trigger_dir: str):
|
|
|
+ """
|
|
|
+ 从触发集中提取密码标签
|
|
|
+ :param trigger_dir: 触发集目录
|
|
|
+ :return: 密码标签
|
|
|
+ """
|
|
|
+ # Initialize variables to store the paths
|
|
|
+ image_folder_path = None
|
|
|
+ qrcode_positions_file_path = None
|
|
|
+ label = ''
|
|
|
+
|
|
|
+ # Walk through the extracted folder to find the specific folder and file
|
|
|
+ for root, dirs, files in os.walk(trigger_dir):
|
|
|
+ if 'images' in dirs:
|
|
|
+ image_folder_path = os.path.join(root, 'images')
|
|
|
+ if 'qrcode_positions.txt' in files:
|
|
|
+ qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
|
|
|
+ if image_folder_path is None:
|
|
|
+ raise FileNotFoundError("触发集目录不存在images文件夹")
|
|
|
+ if qrcode_positions_file_path is None:
|
|
|
+ raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
|
|
|
+
|
|
|
+ sub_image_dir_names = os.listdir(image_folder_path)
|
|
|
+ for sub_image_dir_name in sub_image_dir_names:
|
|
|
+ sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
|
|
|
+ images = os.listdir(sub_pic_dir)
|
|
|
+ for image in images:
|
|
|
+ img_path = os.path.join(sub_pic_dir, image)
|
|
|
+ label_part = qrcode_tool.detect_and_decode_qr_code(img_path)
|
|
|
+ if label_part is not None:
|
|
|
+ label = label + label_part
|
|
|
+ break
|
|
|
+ return label
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
ts = int(time.time())
|
|
|
secret_label, public_key = secret_label_func.generate_secret_label(str(ts))
|