"""Define the SSD inference pipeline (pre-processing, ONNX inference, post-processing)."""
import numpy as np
import onnxruntime as ort
from PIL import Image

from watermark_verify.utils.anchors import get_anchors
from watermark_verify.utils.utils_bbox import BBoxUtility


class SSDInference:
    """Run an SSD object-detection model stored as an ONNX file on single images."""

    def __init__(self, model_path, input_size=(300, 300), num_classes=20, num_iou=0.45, confidence=0.5,
                 swap=(2, 0, 1)):
        """
        Initialise the SSD inference pipeline.

        :param model_path: path of the ONNX model file
        :param input_size: model input size; unpacked as (width, height) when the image is resized
        :param num_classes: number of object-detection classes, forwarded to ``BBoxUtility``
        :param num_iou: IoU threshold forwarded to box decoding as ``nms_iou``
        :param confidence: confidence threshold forwarded to box decoding
        :param swap: axis permutation applied to the pre-processed image with ``np.transpose``;
            the default (2, 0, 1) moves the channel axis first, which PyTorch-exported models need;
            TensorFlow-exported models need no axis change
        """
        self.model_path = model_path
        self.input_size = input_size
        self.swap = swap
        self.num_classes = num_classes
        self.nms_iou = num_iou
        self.confidence = confidence
        # The ONNX Runtime session is created lazily on the first predict() call and then reused.
        self._session = None
        self._session_model_path = None

    def _get_session(self):
        """Return the cached ONNX Runtime session, (re)creating it if ``model_path`` changed."""
        if self._session is None or self._session_model_path != self.model_path:
            self._session = ort.InferenceSession(self.model_path)
            self._session_model_path = self.model_path
        return self._session

    def input_processing(self, image_path):
        """
        Load an image from disk and pre-process it into the model input tensor.

        :param image_path: path of the image file
        :return: tuple ``(image_data, image_shape)``; ``image_data`` is a float32 ndarray with a
            leading batch axis of size 1, ``image_shape`` is an ndarray holding the first two
            dimensions of the original image as reported by ``np.shape``
        """
        # The context manager closes the underlying file as soon as the pixel data has been consumed.
        with Image.open(image_path) as source:
            # np.shape() converts the whole image to an array, so evaluate it only once.
            raw_shape = np.shape(source)
            image_shape = np.array(raw_shape[0:2])

            # ---------------------------------------------------------#
            #   Convert the image to RGB here so that e.g. grayscale
            #   images do not fail during prediction. Only RGB images
            #   are supported; every other layout is converted to RGB.
            # ---------------------------------------------------------#
            is_three_channel = len(raw_shape) == 3 and raw_shape[2] == 3
            rgb_image = source if is_three_channel else source.convert('RGB')

            # resize() returns a new, fully loaded image, so it stays valid after the file is closed.
            resized = resize_image(rgb_image, self.input_size, False)

        # Per-channel means subtracted from the pixel values; kept in float32 so no cast is needed later.
        means = np.array((104, 117, 123), dtype='float32')
        image_data = np.array(resized, dtype='float32') - means
        # Reorder the axes as configured and prepend the batch dimension.
        image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
        return image_data, image_shape

    def predict(self, image_path):
        """
        Run inference on a single image.

        The ONNX Runtime session is built on the first call and reused afterwards, so the model
        file is not reloaded for every image. Replacing the file on disk under the same path is
        not detected; create a new ``SSDInference`` instance in that case.

        :param image_path: path of the image file
        :return: post-processed detection result as returned by ``output_processing``
        """
        image_data, image_shape = self.input_processing(image_path)
        # Run the ONNX model; the image is fed to the model's first declared input.
        session = self._get_session()
        input_name = session.get_inputs()[0].name
        raw_output = session.run(None, {input_name: image_data})
        return self.output_processing(raw_output, image_shape)

    def output_processing(self, outputs, image_shape):
        """
        Post-process the raw model output.

        :param outputs: raw model output as returned by the ONNX Runtime session
        :param image_shape: shape of the original image as returned by ``input_processing``
        :return: result of ``BBoxUtility.decode_box`` for the configured thresholds
        """
        # Decode the predictions against the anchors generated for the model input size.
        bbox_util = BBoxUtility(self.num_classes)
        anchors = get_anchors(self.input_size)
        results = bbox_util.decode_box(outputs, anchors, image_shape, self.input_size, False,
                                       nms_iou=self.nms_iou, confidence=self.confidence)
        return results


def resize_image(image, size, letterbox_image):
    """
    Resize a PIL image to ``size`` using bicubic interpolation.

    :param image: PIL image to resize
    :param size: target size, unpacked as (width, height)
    :param letterbox_image: if true, keep the aspect ratio and centre the scaled image on a grey
        (128, 128, 128) RGB canvas of the target size; otherwise stretch the image to the target size
    :return: the resized PIL image
    """
    iw, ih = image.size
    w, h = size
    if letterbox_image:
        # Largest scale at which the whole image still fits inside the target box.
        scale = min(w / iw, h / ih)
        nw = int(iw * scale)
        nh = int(ih * scale)

        image = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
    else:
        new_image = image.resize((w, h), Image.BICUBIC)
    return new_image