verify_tool.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import os
  2. import cv2
  3. import numpy as np
  4. from watermark_verify import logger
  5. from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
  6. import onnxruntime as ort
  7. def label_verification(model_filename: str) -> bool:
  8. """
  9. 模型标签提取验证
  10. :param model_filename: 模型权重文件,om格式
  11. :return: 模型标签验证结果
  12. """
  13. root_dir = os.path.dirname(model_filename)
  14. label_check_result = False
  15. logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
  16. # step 1 获取触发集目录,公钥信息
  17. trigger_dir = os.path.join(root_dir, 'trigger')
  18. public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
  19. if not os.path.exists(trigger_dir):
  20. logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
  21. raise FileExistsError("触发集目录不存在")
  22. if not os.path.exists(public_key_txt):
  23. logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
  24. raise FileExistsError("签名公钥文件不存在")
  25. with open(public_key_txt, 'r') as file:
  26. public_key = file.read()
  27. logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
  28. if not public_key or public_key == '':
  29. logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
  30. raise RuntimeError("获取的签名公钥信息为空")
  31. qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
  32. if not os.path.exists(qrcode_positions_file):
  33. raise FileNotFoundError("二维码标签文件不存在")
  34. # step 2 获取权重文件,使用触发集进行模型推理, 判断每种触发集推理结果是否为预期图片分类,如果均比对成功则进行下一步,否则返回False
  35. watermark_detect_result = False
  36. cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
  37. accessed_cls = set()
  38. for cls, images in cls_image_mapping.items():
  39. for image in images:
  40. image_path = os.path.join(trigger_dir, image)
  41. detect_result = predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
  42. if detect_result:
  43. accessed_cls.add(cls)
  44. break
  45. if accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
  46. watermark_detect_result = True
  47. if not watermark_detect_result: # 如果没有从模型中检测出黑盒水印,直接返回验证失败
  48. return False
  49. # step 3 从触发集图片中提取密码标签,进行验签
  50. secret_label = extract_crypto_label_from_trigger(trigger_dir)
  51. label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
  52. return label_check_result
  53. def extract_crypto_label_from_trigger(trigger_dir: str):
  54. """
  55. 从触发集中提取密码标签
  56. :param trigger_dir: 触发集目录
  57. :return: 密码标签
  58. """
  59. # Initialize variables to store the paths
  60. image_folder_path = None
  61. qrcode_positions_file_path = None
  62. label = ''
  63. # Walk through the extracted folder to find the specific folder and file
  64. for root, dirs, files in os.walk(trigger_dir):
  65. if 'images' in dirs:
  66. image_folder_path = os.path.join(root, 'images')
  67. if 'qrcode_positions.txt' in files:
  68. qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
  69. if image_folder_path is None:
  70. raise FileNotFoundError("触发集目录不存在images文件夹")
  71. if qrcode_positions_file_path is None:
  72. raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
  73. sub_image_dir_names = os.listdir(image_folder_path)
  74. for sub_image_dir_name in sub_image_dir_names:
  75. sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
  76. images = os.listdir(sub_pic_dir)
  77. for image in images:
  78. img_path = os.path.join(sub_pic_dir, image)
  79. watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
  80. label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
  81. if label_part is not None:
  82. label = label + label_part
  83. break
  84. return label
  85. def preproc(img, input_size, swap=(2, 0, 1)):
  86. if len(img.shape) == 3:
  87. padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
  88. else:
  89. padded_img = np.ones(input_size, dtype=np.uint8) * 114
  90. r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
  91. resized_img = cv2.resize(
  92. img,
  93. (int(img.shape[1] * r), int(img.shape[0] * r)),
  94. interpolation=cv2.INTER_LINEAR,
  95. ).astype(np.uint8)
  96. padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
  97. padded_img = padded_img.transpose(swap)
  98. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
  99. return padded_img, r
  100. def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
  101. # 加载ONNX模型
  102. session = ort.InferenceSession(model_filename)
  103. # 加载图像并进行预处理
  104. origin_img = cv2.imread(image_path)
  105. img, ratio = preproc(origin_img, input_shape)
  106. # 解析标签文件
  107. _, _, _, _, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)
  108. # 执行推理
  109. input_name = session.get_inputs()[0].name
  110. output_name = session.get_outputs()[0].name
  111. result = session.run([output_name], {input_name: img[None, :, :, :]})[0]
  112. # 处理输出结果
  113. predicted_class = np.argmax(result, axis=1)[0]
  114. return cls == predicted_class