ssd_pytorch_blackbox_process.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. """
  2. ssd基于pytorch框架的黑盒水印处理验证流程
  3. """
  4. import os
  5. import numpy as np
  6. from PIL import Image
  7. from watermark_verify.inference.ssd_inference import SSDInference
  8. from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
  9. from watermark_verify.tools import parse_qrcode_label_file
  10. from watermark_verify.tools.evaluate_tool import calculate_ciou
  11. class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
  12. def __init__(self, model_filename):
  13. super(ModelWatermarkProcessor, self).__init__(model_filename)
  14. def process(self) -> bool:
  15. # 获取权重文件,使用触发集进行模型推理, 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
  16. cls_image_mapping = parse_qrcode_label_file.parse_labels(self.qrcode_positions_file)
  17. accessed_cls = set()
  18. total = 0 # 总检测次数
  19. passed = 0 # 成功检测次数
  20. for cls, images in cls_image_mapping.items():
  21. for i, image in enumerate(images):
  22. image_path = os.path.join(self.trigger_dir, image)
  23. # 使用SSD模型进行黑盒水印检测
  24. try:
  25. detect_result = self.detect_secret_label(image_path, self.model_filename, self.qrcode_positions_file, (300, 300))
  26. except Exception as e:
  27. continue
  28. # 统计检测结果
  29. total += 1
  30. if detect_result:
  31. passed += 1
  32. if i == 499:
  33. accessed_cls.add(cls)
  34. break
  35. success_rate = 100.0 * passed / total if total > 0 else 0.0
  36. print(f"\n\r---------- 水印检测成功率:{passed} / {total} = {success_rate:.2f}% ----------\n\r")
  37. if not accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
  38. return False
  39. verify_result = self.verify_label() # 模型标签检测通过,进行标签验证
  40. return verify_result
  41. def detect_secret_label(self, image_path, model_file, watermark_txt, input_shape) -> bool:
  42. """
  43. 使用指定onnx文件进行预测并进行黑盒水印检测
  44. :param image_path: 输入图像路径
  45. :param model_file: 模型文件路径
  46. :param watermark_txt: 水印标签文件路径
  47. :param input_shape: 模型输入图像大小,tuple
  48. :return:
  49. """
  50. image = Image.open(image_path)
  51. image_shape = np.array(np.shape(image)[0:2])
  52. # 解析标签嵌入位置
  53. parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  54. if len(parse_label) < 5:
  55. return False
  56. x_center, y_center, w, h, cls = parse_label
  57. # 计算绝对坐标
  58. height, width = image_shape
  59. x1 = (x_center - w / 2) * width
  60. y1 = (y_center - h / 2) * height
  61. x2 = (x_center + w / 2) * width
  62. y2 = (y_center + h / 2) * height
  63. watermark_box = [y1, x1, y2, x2, cls]
  64. if len(watermark_box) == 0:
  65. return False
  66. # 使用onnx进行推理
  67. results = SSDInference(self.model_filename).predict(image_path)
  68. # 检测模型是否存在黑盒水印
  69. if results is not None:
  70. detect_result = detect_watermark(results, watermark_box)
  71. return detect_result
  72. else:
  73. return False
  74. def detect_watermark(results, watermark_box, threshold=0.5):
  75. # 解析输出结果
  76. if len(results[0]) == 0:
  77. return False
  78. top_label = np.array(results[0][:, 4], dtype='int32')
  79. top_conf = results[0][:, 5]
  80. top_boxes = results[0][:, :4]
  81. for box, score, cls in zip(top_boxes, top_conf, top_label):
  82. wm_box_coords = watermark_box[:4]
  83. wm_cls = watermark_box[4]
  84. if cls == wm_cls:
  85. ciou = calculate_ciou(box, wm_box_coords)
  86. print(f"检测到的类别: {cls}, 置信度: {score}, 相似度: {ciou}")
  87. if ciou > threshold:
  88. return True
  89. return False