1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- """
- 定义图像分类推理流程
- """
- import numpy as np
- from PIL import Image
- import onnxruntime as ort
- class ClassificationInference:
- def __init__(self, model_path, input_size=(224, 224), swap=(2, 0, 1)):
- """
- 初始化图像分类模型推理流程
- :param model_path: 图像分类模型onnx文件路径
- :param input_size: 模型输入大小
- :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
- """
- self.model_path = model_path
- self.input_size = input_size
- self.swap = swap
- def input_processing(self, image_path):
- """
- 对单个图像输入进行处理
- :param image_path: 图像路径
- :return: 处理后输出
- """
- # 打开图像并转换为RGB
- image = Image.open(image_path).convert("RGB")
- # 调整图像大小
- image = image.resize(self.input_size)
- # 转换为numpy数组并归一化
- image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
- # 进行标准化
- mean = np.array([0.485, 0.456, 0.406])
- std = np.array([0.229, 0.224, 0.225])
- image_array = (image_array - mean) / std
- image_array = image_array.transpose(self.swap).copy()
- return image_array.astype(np.float32)
- def predict(self, image_path):
- """
- 对单张图片进行推理
- :param image_path: 图片路径
- :return: 推理结果
- """
- session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
- input_name = session.get_inputs()[0].name
- image = self.input_processing(image_path)
- # 执行预测
- outputs = session.run(None, {input_name: np.expand_dims(image, axis=0)})
- return outputs
- def predict_batch(self, image_paths):
- """
- 对指定图片列表进行批量推理
- :param image_paths: 待推理的图片路径列表
- :return: 批量推理结果
- """
- session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
- input_name = session.get_inputs()[0].name
- batch_images = []
- for image_path in image_paths:
- image = self.input_processing(image_path)
- batch_images.append(image)
- # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
- batch_images = np.stack(batch_images)
- # 执行预测
- outputs = session.run(None, {input_name: batch_images})
- return outputs
|