verify_tool.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. from watermark_verify import logger
  3. from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
  4. def label_verification(model_filename: str) -> bool:
  5. """
  6. 模型标签提取验证
  7. :param model_filename: 模型权重文件,om格式
  8. :return: 模型标签验证结果
  9. """
  10. root_dir = os.path.dirname(model_filename)
  11. label_check_result = False
  12. logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
  13. # step 1 获取触发集目录,公钥信息
  14. trigger_dir = os.path.join(root_dir, 'trigger')
  15. public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
  16. if not os.path.exists(trigger_dir):
  17. logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
  18. raise FileExistsError("触发集目录不存在")
  19. if not os.path.exists(public_key_txt):
  20. logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
  21. raise FileExistsError("签名公钥文件不存在")
  22. with open(public_key_txt, 'r') as file:
  23. public_key = file.read()
  24. logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
  25. if not public_key or public_key == '':
  26. logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
  27. raise RuntimeError("获取的签名公钥信息为空")
  28. # step 2 获取权重文件,使用触发集进行模型推理
  29. # step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
  30. # step 4 从触发集图片中提取密码标签,进行验签
  31. secret_label = extract_crypto_label_from_trigger(trigger_dir)
  32. label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
  33. return label_check_result
  34. def extract_crypto_label_from_trigger(trigger_dir: str):
  35. """
  36. 从触发集中提取密码标签
  37. :param trigger_dir: 触发集目录
  38. :return: 密码标签
  39. """
  40. # Initialize variables to store the paths
  41. image_folder_path = None
  42. qrcode_positions_file_path = None
  43. label = ''
  44. # Walk through the extracted folder to find the specific folder and file
  45. for root, dirs, files in os.walk(trigger_dir):
  46. if 'images' in dirs:
  47. image_folder_path = os.path.join(root, 'images')
  48. if 'qrcode_positions.txt' in files:
  49. qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
  50. if image_folder_path is None:
  51. raise FileNotFoundError("触发集目录不存在images文件夹")
  52. if qrcode_positions_file_path is None:
  53. raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
  54. sub_image_dir_names = os.listdir(image_folder_path)
  55. for sub_image_dir_name in sub_image_dir_names:
  56. sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
  57. images = os.listdir(sub_pic_dir)
  58. for image in images:
  59. img_path = os.path.join(sub_pic_dir, image)
  60. watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
  61. label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
  62. if label_part is not None:
  63. label = label + label_part
  64. break
  65. return label