rcnn.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import numpy as np
  2. import onnxruntime
  3. from PIL import Image
  4. from watermark_verify.inference.yolox import compute_ciou
  5. from watermark_verify.tools import parse_qrcode_label_file
  6. from watermark_verify.utils.utils_bbox import DecodeBox
  7. # ---------------------------------------------------------#
  8. # 将图像转换成RGB图像,防止灰度图在预测时报错。
  9. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  10. # ---------------------------------------------------------#
  11. def cvtColor(image):
  12. if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
  13. return image
  14. else:
  15. image = image.convert('RGB')
  16. return image
  17. # ---------------------------------------------------#
  18. # 对输入图像进行resize
  19. # ---------------------------------------------------#
  20. def resize_image(image, size, letterbox_image):
  21. iw, ih = image.size
  22. w, h = size
  23. if letterbox_image:
  24. scale = min(w / iw, h / ih)
  25. nw = int(iw * scale)
  26. nh = int(ih * scale)
  27. image = image.resize((nw, nh), Image.BICUBIC)
  28. new_image = Image.new('RGB', size, (128, 128, 128))
  29. new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
  30. else:
  31. new_image = image.resize((w, h), Image.BICUBIC)
  32. return new_image
  33. # ---------------------------------------------------#
  34. # 获得学习率
  35. # ---------------------------------------------------#
  36. def preprocess_input(image):
  37. image /= 255.0
  38. return image
  39. def get_new_img_size(height, width, img_min_side=600):
  40. if width <= height:
  41. f = float(img_min_side) / width
  42. resized_height = int(f * height)
  43. resized_width = int(img_min_side)
  44. else:
  45. f = float(img_min_side) / height
  46. resized_width = int(f * width)
  47. resized_height = int(img_min_side)
  48. return resized_height, resized_width
  49. # ---------------------------------------------------#
  50. # 处理输入图像
  51. # ---------------------------------------------------#
  52. def deal_img(img_path, resized_size):
  53. image = Image.open(img_path)
  54. image_shape = np.array(np.shape(image)[0:2])
  55. # ---------------------------------------------------------#
  56. # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
  57. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  58. # ---------------------------------------------------------#
  59. image = cvtColor(image)
  60. image_data = resize_image(image, resized_size, False)
  61. image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
  62. image_data = image_data.astype('float32')
  63. return image_data, image_shape
  64. # ---------------------------------------------------#
  65. # 检测图像水印
  66. # ---------------------------------------------------#
  67. def detect_watermark(results, watermark_box, threshold=0.5):
  68. # 解析输出结果
  69. if len(results[0]) == 0:
  70. return False
  71. top_label = np.array(results[0][:, 5], dtype='int32')
  72. top_conf = results[0][:, 4]
  73. top_boxes = results[0][:, :4]
  74. for box, score, cls in zip(top_boxes, top_conf, top_label):
  75. wm_box_coords = watermark_box[:4]
  76. wm_cls = watermark_box[4]
  77. if cls == wm_cls:
  78. ciou = compute_ciou(box, wm_box_coords)
  79. if ciou > threshold:
  80. return True
  81. return False
  82. def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bool:
  83. """
  84. 使用指定onnx文件进行预测并进行黑盒水印检测
  85. :param image_path: 输入图像路径
  86. :param model_file: 模型文件路径
  87. :param watermark_txt: 水印标签文件路径
  88. :param input_shape: 模型输入图像大小,tuple
  89. :return:
  90. """
  91. image_data, image_shape = deal_img(image_path, input_shape)
  92. # 解析标签嵌入位置
  93. parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  94. if len(parse_label) < 5:
  95. return False
  96. x_center, y_center, w, h, cls = parse_label
  97. # 计算绝对坐标
  98. height, width = image_shape
  99. x1 = (x_center - w / 2) * width
  100. y1 = (y_center - h / 2) * height
  101. x2 = (x_center + w / 2) * width
  102. y2 = (y_center + h / 2) * height
  103. watermark_box = [x1, y1, x2, y2, cls]
  104. if len(watermark_box) == 0:
  105. return False
  106. # 使用onnx进行推理
  107. session = onnxruntime.InferenceSession(model_file)
  108. ort_inputs = {session.get_inputs()[0].name: image_data, session.get_inputs()[1].name: np.array(1.0).astype('float64')}
  109. output = session.run(None, ort_inputs)
  110. roi_cls_locs, roi_scores, rois, _ = output
  111. # 处理模型预测输出
  112. num_classes = 20
  113. bbox_util = DecodeBox(num_classes)
  114. nms_iou = 0.3
  115. confidence = 0.5
  116. results = bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape,
  117. nms_iou = nms_iou, confidence = confidence)
  118. if results is not None:
  119. detect_result = detect_watermark(results, watermark_box)
  120. return detect_result
  121. else:
  122. return False