verify_tool.py 6.6 KB

import os

import numpy as np
from PIL import Image
from watermark_verify import logger
from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
import onnxruntime as ort


def label_verification(model_filename: str) -> bool:
    """
    Extract and verify the model label.
    :param model_filename: model weight file in ONNX format
    :return: result of the model label verification
    """
    root_dir = os.path.dirname(model_filename)
    logger.info(f"Starting model watermark detection, model_filename: {model_filename}, root_dir: {root_dir}")
    # step 1: locate the trigger-set directory and the signature public key
    trigger_dir = os.path.join(root_dir, 'trigger')
    public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
    if not os.path.exists(trigger_dir):
        logger.error(f"trigger_dir={trigger_dir}, trigger-set directory does not exist")
        raise FileNotFoundError("trigger-set directory does not exist")
    if not os.path.exists(public_key_txt):
        logger.error(f"public_key_txt={public_key_txt}, signature public key file does not exist")
        raise FileNotFoundError("signature public key file does not exist")
    with open(public_key_txt, 'r') as file:
        public_key = file.read()
    logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
    if not public_key:
        logger.error(f"the signature public key is empty, public_key={public_key}")
        raise RuntimeError("the signature public key is empty")
    qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
    if not os.path.exists(qrcode_positions_file):
        raise FileNotFoundError("QR-code label file does not exist")

    # step 2: run batched inference on the trigger set; if a batch's accuracy exceeds
    # the threshold the comparison succeeds and we move on, otherwise return False
    # load the ONNX model
    session = ort.InferenceSession(model_filename)
    for i in range(0, 2):
        image_dir = os.path.join(trigger_dir, 'images', str(i))
        if not os.path.exists(image_dir):
            logger.error(f"the specified trigger-set image path does not exist, image_dir={image_dir}")
            return False
        # Keras/TensorFlow models expect NHWC input, so skip the HWC -> CHW transpose
        transpose = not ("keras" in model_filename or "tensorflow" in model_filename)
        batch_result = batch_predict_images(session, image_dir, i, transpose=transpose)
        if not batch_result:
            return False

    # step 3: extract the cryptographic label from the trigger-set images and verify its signature
    secret_label = extract_crypto_label_from_trigger(trigger_dir)
    label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
    return label_check_result
def extract_crypto_label_from_trigger(trigger_dir: str):
    """
    Extract the cryptographic label from the trigger set.
    :param trigger_dir: trigger-set directory
    :return: cryptographic label
    """
    # Initialize variables to store the paths
    image_folder_path = None
    qrcode_positions_file_path = None
    label = ''
    # Walk through the extracted folder to find the specific folder and file
    for root, dirs, files in os.walk(trigger_dir):
        if 'images' in dirs:
            image_folder_path = os.path.join(root, 'images')
        if 'qrcode_positions.txt' in files:
            qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
    if image_folder_path is None:
        raise FileNotFoundError("the trigger-set directory does not contain an images folder")
    if qrcode_positions_file_path is None:
        raise FileNotFoundError("the trigger-set directory does not contain qrcode_positions.txt")
    sub_image_dir_names = os.listdir(image_folder_path)
    for sub_image_dir_name in sub_image_dir_names:
        sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
        images = os.listdir(sub_pic_dir)
        for image in images:
            img_path = os.path.join(sub_pic_dir, image)
            watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
            if label_part is not None:
                # one successfully decoded QR code per sub-directory is enough
                label = label + label_part
                break
    return label
def process_image(image_path, transpose=True):
    # open the image and convert it to RGB
    image = Image.open(image_path).convert("RGB")
    # resize the image
    image = image.resize((224, 224))
    # convert to a numpy array and scale pixel values to [0, 1]
    image_array = np.array(image) / 255.0
    # normalize with ImageNet mean and std
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_array = (image_array - mean) / std
    if transpose:
        # HWC -> CHW for channels-first models
        image_array = image_array.transpose((2, 0, 1)).copy()
    return image_array.astype(np.float32)
def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10, transpose=True):
    """
    Run batched inference on the images in the given folder.
    :param session: ONNX Runtime session
    :param image_dir: folder with the images to run inference on
    :param target_class: target class
    :param threshold: accuracy threshold for passing the test
    :param batch_size: number of images per batch
    :param transpose: whether to transpose images from HWC to CHW
    :return: detection result
    """
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    results = {}
    input_name = session.get_inputs()[0].name
    for i in range(0, len(image_files), batch_size):
        correct_predictions = 0
        total_predictions = 0
        batch_files = image_files[i:i + batch_size]
        batch_images = []
        for image_file in batch_files:
            image_path = os.path.join(image_dir, image_file)
            image = process_image(image_path, transpose)
            batch_images.append(image)
        # stack the batch into a (batch_size, 3, 224, 224) array
        batch_images = np.stack(batch_images)
        # run inference
        outputs = session.run(None, {input_name: batch_images})
        # extract the predictions
        for j, image_file in enumerate(batch_files):
            predicted_class = np.argmax(outputs[0][j])  # assumes the output holds per-class scores
            results[image_file] = predicted_class
            total_predictions += 1
            # compare the prediction with the target class
            if predicted_class == target_class:
                correct_predictions += 1
        # compute the batch accuracy
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        # logger.debug(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
        if accuracy >= threshold:
            logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} >= threshold {threshold}")
            return True
    return False
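

# A minimal usage sketch, assuming the ONNX model file sits next to the trigger/
# and keys/ directories that label_verification expects; the path below is a
# hypothetical placeholder, not part of the original module's interface.
if __name__ == "__main__":
    verified = label_verification("/path/to/model_dir/model.onnx")
    logger.info(f"label verification passed: {verified}")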