verify_tool.py 4.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. from watermark_verify.inference import yolox
  3. from watermark_verify import logger
  4. from watermark_verify.tools import secret_label_func, qrcode_tool, general_tool, parse_qrcode_label_file
  5. def label_verification(model_filename: str) -> bool:
  6. """
  7. 模型标签提取验证
  8. :param model_filename: 模型权重文件,onnx格式
  9. :return: 模型标签验证结果
  10. """
  11. if not os.path.exists(model_filename):
  12. logger.error(f"model_filename={model_filename}指定模型权重文件不存在")
  13. raise FileNotFoundError("指定模型权重文件不存在")
  14. file_extension = general_tool.get_file_extension(model_filename)
  15. if file_extension != "onnx":
  16. logger.error(f"模型权重文件格式不合法")
  17. raise RuntimeError(f"模型权重文件格式不合法")
  18. root_dir = os.path.dirname(model_filename)
  19. logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
  20. # step 1 获取触发集目录,公钥信息
  21. trigger_dir = os.path.join(root_dir, 'trigger')
  22. public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
  23. if not os.path.exists(trigger_dir):
  24. logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
  25. raise FileNotFoundError("触发集目录不存在")
  26. if not os.path.exists(public_key_txt):
  27. logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
  28. raise FileNotFoundError("签名公钥文件不存在")
  29. with open(public_key_txt, 'r') as file:
  30. public_key = file.read()
  31. logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
  32. if not public_key or public_key == '':
  33. logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
  34. raise RuntimeError("获取的签名公钥信息为空")
  35. qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
  36. if not os.path.exists(qrcode_positions_file):
  37. raise FileNotFoundError("二维码标签文件不存在")
  38. # step 2 获取权重文件,使用触发集进行模型推理, 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
  39. watermark_detect_result = False
  40. cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
  41. accessed_cls = set()
  42. for cls, images in cls_image_mapping.items():
  43. for image in images:
  44. image_path = os.path.join(trigger_dir, image)
  45. detect_result = yolox.predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
  46. if detect_result:
  47. accessed_cls.add(cls)
  48. break
  49. if accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
  50. watermark_detect_result = True
  51. if not watermark_detect_result: # 如果没有从模型中检测出黑盒水印,直接返回验证失败
  52. return False
  53. # step 3 从触发集图片中提取密码标签,进行验签
  54. secret_label = extract_crypto_label_from_trigger(trigger_dir)
  55. label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
  56. return label_check_result
  57. def extract_crypto_label_from_trigger(trigger_dir: str):
  58. """
  59. 从触发集中提取密码标签
  60. :param trigger_dir: 触发集目录
  61. :return: 密码标签
  62. """
  63. # Initialize variables to store the paths
  64. image_folder_path = None
  65. qrcode_positions_file_path = None
  66. label = ''
  67. # Walk through the extracted folder to find the specific folder and file
  68. for root, dirs, files in os.walk(trigger_dir):
  69. if 'images' in dirs:
  70. image_folder_path = os.path.join(root, 'images')
  71. if 'qrcode_positions.txt' in files:
  72. qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
  73. if image_folder_path is None:
  74. raise FileNotFoundError("触发集目录不存在images文件夹")
  75. if qrcode_positions_file_path is None:
  76. raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
  77. sub_image_dir_names = os.listdir(image_folder_path)
  78. for sub_image_dir_name in sub_image_dir_names:
  79. sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
  80. images = os.listdir(sub_pic_dir)
  81. for image in images:
  82. img_path = os.path.join(sub_pic_dir, image)
  83. watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
  84. label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
  85. if label_part is not None:
  86. label = label + label_part
  87. break
  88. return label