faster-rcnn_pytorch_blackbox_process.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """
  2. faster-rcnn基于pytorch框架的黑盒水印处理验证流程
  3. """
  4. import os
  5. import numpy as np
  6. import onnxruntime
  7. from PIL import Image
  8. from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
  9. from watermark_verify.tools import parse_qrcode_label_file
  10. from watermark_verify.tools.evaluate_tool import calculate_ciou
  11. from watermark_verify.utils.utils_bbox import DecodeBox
  12. class ClassificationProcess(BlackBoxWatermarkProcessDefine):
  13. def __init__(self, model_filename):
  14. super(ClassificationProcess, self).__init__(model_filename)
  15. def process(self) -> bool:
  16. # 获取权重文件,使用触发集进行模型推理, 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
  17. cls_image_mapping = parse_qrcode_label_file.parse_labels(self.qrcode_positions_file)
  18. accessed_cls = set()
  19. for cls, images in cls_image_mapping.items():
  20. for image in images:
  21. image_path = os.path.join(self.trigger_dir, image)
  22. try:
  23. detect_result = self.detect_secret_label(image_path, self.model_filename,
  24. self.qrcode_positions_file,
  25. (600, 600))
  26. except Exception as e:
  27. continue
  28. if detect_result:
  29. accessed_cls.add(cls)
  30. break
  31. if not accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
  32. return False
  33. verify_result = self.verify_label() # 模型标签检测通过,进行标签验证
  34. return verify_result
  35. def preprocess_image(self, image_path, input_size, swap=(2, 0, 1)):
  36. image = Image.open(image_path)
  37. image_shape = np.array(np.shape(image)[0:2])
  38. # ---------------------------------------------------------#
  39. # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
  40. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  41. # ---------------------------------------------------------#
  42. if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
  43. image = image.convert('RGB')
  44. image_data = resize_image(image, input_size, False)
  45. image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), swap).copy(),
  46. 0)
  47. image_data = image_data.astype('float32')
  48. return image_data, image_shape
  49. def detect_secret_label(self, image_path, model_file, watermark_txt, input_shape) -> bool:
  50. """
  51. 使用指定onnx文件进行预测并进行黑盒水印检测
  52. :param image_path: 输入图像路径
  53. :param model_file: 模型文件路径
  54. :param watermark_txt: 水印标签文件路径
  55. :param input_shape: 模型输入图像大小,tuple
  56. :return:
  57. """
  58. image_data, image_shape = self.preprocess_image(image_path, input_shape)
  59. # 解析标签嵌入位置
  60. parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  61. if len(parse_label) < 5:
  62. return False
  63. x_center, y_center, w, h, cls = parse_label
  64. # 计算绝对坐标
  65. height, width = image_shape
  66. x1 = (x_center - w / 2) * width
  67. y1 = (y_center - h / 2) * height
  68. x2 = (x_center + w / 2) * width
  69. y2 = (y_center + h / 2) * height
  70. watermark_box = [y1, x1, y2, x2, cls]
  71. if len(watermark_box) == 0:
  72. return False
  73. # 使用onnx进行推理
  74. session = onnxruntime.InferenceSession(model_file)
  75. ort_inputs = {session.get_inputs()[0].name: image_data,
  76. session.get_inputs()[1].name: np.array(1.0).astype('float64')}
  77. output = session.run(None, ort_inputs)
  78. roi_cls_locs, roi_scores, rois, _ = output
  79. # 处理模型预测输出
  80. num_classes = 20
  81. bbox_util = DecodeBox(num_classes)
  82. nms_iou = 0.3
  83. confidence = 0.5
  84. results = bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
  85. nms_iou=nms_iou, confidence=confidence)
  86. if results is not None:
  87. detect_result = detect_watermark(results, watermark_box)
  88. return detect_result
  89. else:
  90. return False
  91. def resize_image(image, size, letterbox_image):
  92. iw, ih = image.size
  93. w, h = size
  94. if letterbox_image:
  95. scale = min(w / iw, h / ih)
  96. nw = int(iw * scale)
  97. nh = int(ih * scale)
  98. image = image.resize((nw, nh), Image.BICUBIC)
  99. new_image = Image.new('RGB', size, (128, 128, 128))
  100. new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
  101. else:
  102. new_image = image.resize((w, h), Image.BICUBIC)
  103. return new_image
  104. def preprocess_input(inputs):
  105. MEANS = (104, 117, 123)
  106. return inputs - MEANS
  107. def detect_watermark(results, watermark_box, threshold=0.5):
  108. # 解析输出结果
  109. if len(results[0]) == 0:
  110. return False
  111. top_label = np.array(results[0][:, 4], dtype='int32')
  112. top_conf = results[0][:, 5]
  113. top_boxes = results[0][:, :4]
  114. for box, score, cls in zip(top_boxes, top_conf, top_label):
  115. wm_box_coords = watermark_box[:4]
  116. wm_cls = watermark_box[4]
  117. if cls == wm_cls:
  118. ciou = calculate_ciou(box, wm_box_coords)
  119. if ciou > threshold:
  120. return True
  121. return False