rcnn.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import numpy as np
  2. import onnxruntime
  3. from PIL import Image
  4. from watermark_verify.tools import parse_qrcode_label_file
  5. from watermark_verify.utils.utils_bbox import DecodeBox
  6. def compute_ciou(box1, box2):
  7. """计算CIoU,假设box格式为[x1, y1, x2, y2]"""
  8. x1, y1, x2, y2 = box1
  9. x1g, y1g, x2g, y2g = box2
  10. # 求交集面积
  11. xi1, yi1 = max(x1, x1g), max(y1, y1g)
  12. xi2, yi2 = min(x2, x2g), min(y2, y2g)
  13. inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
  14. # 求各自面积
  15. box_area = (x2 - x1) * (y2 - y1)
  16. boxg_area = (x2g - x1g) * (y2g - y1g)
  17. # 求并集面积
  18. union_area = box_area + boxg_area - inter_area
  19. # 求IoU
  20. iou = inter_area / union_area
  21. # 求CIoU额外项
  22. cw = max(x2, x2g) - min(x1, x1g)
  23. ch = max(y2, y2g) - min(y1, y1g)
  24. c2 = cw ** 2 + ch ** 2
  25. rho2 = ((x1 + x2 - x1g - x2g) ** 2 + (y1 + y2 - y1g - y2g) ** 2) / 4
  26. ciou = iou - (rho2 / c2)
  27. return ciou
  28. # ---------------------------------------------------------#
  29. # 将图像转换成RGB图像,防止灰度图在预测时报错。
  30. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  31. # ---------------------------------------------------------#
  32. def cvtColor(image):
  33. if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
  34. return image
  35. else:
  36. image = image.convert('RGB')
  37. return image
  38. # ---------------------------------------------------#
  39. # 对输入图像进行resize
  40. # ---------------------------------------------------#
  41. def resize_image(image, size, letterbox_image):
  42. iw, ih = image.size
  43. w, h = size
  44. if letterbox_image:
  45. scale = min(w / iw, h / ih)
  46. nw = int(iw * scale)
  47. nh = int(ih * scale)
  48. image = image.resize((nw, nh), Image.BICUBIC)
  49. new_image = Image.new('RGB', size, (128, 128, 128))
  50. new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
  51. else:
  52. new_image = image.resize((w, h), Image.BICUBIC)
  53. return new_image
  54. # ---------------------------------------------------#
  55. # 获得学习率
  56. # ---------------------------------------------------#
  57. def preprocess_input(image):
  58. image /= 255.0
  59. return image
  60. def get_new_img_size(height, width, img_min_side=600):
  61. if width <= height:
  62. f = float(img_min_side) / width
  63. resized_height = int(f * height)
  64. resized_width = int(img_min_side)
  65. else:
  66. f = float(img_min_side) / height
  67. resized_width = int(f * width)
  68. resized_height = int(img_min_side)
  69. return resized_height, resized_width
  70. # ---------------------------------------------------#
  71. # 处理输入图像
  72. # ---------------------------------------------------#
  73. def deal_img(img_path, resized_size):
  74. image = Image.open(img_path)
  75. image_shape = np.array(np.shape(image)[0:2])
  76. # ---------------------------------------------------------#
  77. # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
  78. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  79. # ---------------------------------------------------------#
  80. image = cvtColor(image)
  81. image_data = resize_image(image, resized_size, False)
  82. image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
  83. image_data = image_data.astype('float32')
  84. return image_data, image_shape
  85. # ---------------------------------------------------#
  86. # 检测图像水印
  87. # ---------------------------------------------------#
  88. def detect_watermark(results, watermark_box, threshold=0.5):
  89. # 解析输出结果
  90. if len(results[0]) == 0:
  91. return False
  92. top_label = np.array(results[0][:, 5], dtype='int32')
  93. top_conf = results[0][:, 4]
  94. top_boxes = results[0][:, :4]
  95. for box, score, cls in zip(top_boxes, top_conf, top_label):
  96. wm_box_coords = watermark_box[:4]
  97. wm_cls = watermark_box[4]
  98. if cls == wm_cls:
  99. ciou = compute_ciou(box, wm_box_coords)
  100. if ciou > threshold:
  101. return True
  102. return False
  103. def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bool:
  104. """
  105. 使用指定onnx文件进行预测并进行黑盒水印检测
  106. :param image_path: 输入图像路径
  107. :param model_file: 模型文件路径
  108. :param watermark_txt: 水印标签文件路径
  109. :param input_shape: 模型输入图像大小,tuple
  110. :return:
  111. """
  112. image_data, image_shape = deal_img(image_path, input_shape)
  113. # 解析标签嵌入位置
  114. parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  115. if len(parse_label) < 5:
  116. return False
  117. x_center, y_center, w, h, cls = parse_label
  118. # 计算绝对坐标
  119. height, width = image_shape
  120. x1 = (x_center - w / 2) * width
  121. y1 = (y_center - h / 2) * height
  122. x2 = (x_center + w / 2) * width
  123. y2 = (y_center + h / 2) * height
  124. watermark_box = [y1, x1, y2, x2, cls]
  125. if len(watermark_box) == 0:
  126. return False
  127. # 使用onnx进行推理
  128. session = onnxruntime.InferenceSession(model_file)
  129. ort_inputs = {session.get_inputs()[0].name: image_data, session.get_inputs()[1].name: np.array(1.0).astype('float64')}
  130. output = session.run(None, ort_inputs)
  131. roi_cls_locs, roi_scores, rois, _ = output
  132. # 处理模型预测输出
  133. num_classes = 20
  134. bbox_util = DecodeBox(num_classes)
  135. nms_iou = 0.3
  136. confidence = 0.5
  137. results = bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
  138. nms_iou = nms_iou, confidence = confidence)
  139. if results is not None:
  140. detect_result = detect_watermark(results, watermark_box)
  141. return detect_result
  142. else:
  143. return False