rcnn_inference.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. 定义Faster-RCNN推理流程
  3. """
  4. import numpy as np
  5. import onnxruntime as ort
  6. from PIL import Image
  7. from watermark_verify.utils.utils_bbox import DecodeBox
  8. class FasterRCNNInference:
  9. def __init__(self, model_path, input_size=(600, 600), num_classes=20, num_iou=0.3, confidence=0.5, swap=(2, 0, 1)):
  10. """
  11. 初始化Faster-RCNN模型推理流程
  12. :param model_path: 图像分类模型onnx文件路径
  13. :param input_size: 模型输入大小
  14. :param num_classes: 模型目标检测分类数
  15. :param num_iou: iou阈值
  16. :param confidence: 置信度阈值
  17. :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
  18. """
  19. self.model_path = model_path
  20. self.input_size = input_size
  21. self.swap = swap
  22. self.num_classes = num_classes
  23. self.nms_iou = num_iou
  24. self.confidence = confidence
  25. def input_processing(self, image_path):
  26. """
  27. 对输入图片进行预处理
  28. :param image_path: 图片路径
  29. :return: 图片经过处理完成的ndarray
  30. """
  31. image = Image.open(image_path)
  32. image_shape = np.array(np.shape(image)[0:2])
  33. # ---------------------------------------------------------#
  34. # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
  35. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  36. # ---------------------------------------------------------#
  37. if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
  38. image = image.convert('RGB')
  39. image_data = resize_image(image, self.input_size, False)
  40. image_data = np.array(image_data, dtype='float32')
  41. image_data = image_data / 255.0
  42. image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
  43. image_data = image_data.astype('float32')
  44. return image_data, image_shape
  45. def predict(self, image_path):
  46. """
  47. 对单张图片进行推理
  48. :param image_path: 图片路径
  49. :return: 推理结果
  50. """
  51. image_data, image_shape = self.input_processing(image_path)
  52. # 使用onnx文件进行推理
  53. session = ort.InferenceSession(self.model_path)
  54. ort_inputs = {session.get_inputs()[0].name: image_data,
  55. session.get_inputs()[1].name: np.array(1.0).astype('float32')}
  56. output = session.run(None, ort_inputs)
  57. output = self.output_processing(output, image_shape)
  58. return output
  59. def output_processing(self, outputs, image_shape):
  60. """
  61. 对模型输出进行后处理工作
  62. :param outputs: 模型原始输出
  63. :param image_shape: 原始图像大小
  64. :return: 经过处理完成的模型输出
  65. """
  66. # 处理模型预测输出
  67. roi_cls_locs, roi_scores, rois, _ = outputs
  68. bbox_util = DecodeBox(self.num_classes)
  69. results = bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, self.input_size,
  70. nms_iou=self.nms_iou, confidence=self.confidence)
  71. return results
  72. def resize_image(image, size, letterbox_image):
  73. iw, ih = image.size
  74. w, h = size
  75. if letterbox_image:
  76. scale = min(w / iw, h / ih)
  77. nw = int(iw * scale)
  78. nh = int(ih * scale)
  79. image = image.resize((nw, nh), Image.BICUBIC)
  80. new_image = Image.new('RGB', size, (128, 128, 128))
  81. new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
  82. else:
  83. new_image = image.resize((w, h), Image.BICUBIC)
  84. return new_image