|
@@ -8,23 +8,28 @@ import onnxruntime as ort
|
|
|
|
|
|
|
|
|
class ClassificationInference:
|
|
|
- def __init__(self, model_path, swap=(2, 0, 1)):
|
|
|
- self.swap = swap
|
|
|
+ def __init__(self, model_path, input_size=(224, 224), swap=(2, 0, 1)):
|
|
|
+ """
|
|
|
+ 初始化图像分类模型推理流程
|
|
|
+ :param model_path: 图像分类模型onnx文件路径
|
|
|
+ :param input_size: 模型输入大小,格式为 (宽, 高),会原样传给 Image.resize
|
|
|
+ :param swap: 轴变换顺序,会原样传给 ndarray.transpose;pytorch需要进行轴变换,使用默认值 (2, 0, 1) 将 HWC 转为 CHW;tensorflow无需进行轴变换,应传入 (0, 1, 2)(注意不要传 None,transpose(None) 会将轴顺序反转)
|
|
|
+ """
|
|
|
self.model_path = model_path
|
|
|
+ self.input_size = input_size
|
|
|
+ self.swap = swap
|
|
|
|
|
|
- def input_processing(self, image_path, input_size=(224, 224), swap=(2, 0, 1)):
|
|
|
+ def input_processing(self, image_path):
|
|
|
"""
|
|
|
对单个图像输入进行处理
|
|
|
:param image_path: 图像路径
|
|
|
- :param input_size: 模型输入大小
|
|
|
- :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
|
|
|
:return: 处理后输出
|
|
|
"""
|
|
|
# 打开图像并转换为RGB
|
|
|
image = Image.open(image_path).convert("RGB")
|
|
|
|
|
|
# 调整图像大小
|
|
|
- image = image.resize(input_size)
|
|
|
+ image = image.resize(self.input_size)
|
|
|
|
|
|
# 转换为numpy数组并归一化
|
|
|
image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
|
|
@@ -33,7 +38,7 @@ class ClassificationInference:
|
|
|
mean = np.array([0.485, 0.456, 0.406])
|
|
|
std = np.array([0.229, 0.224, 0.225])
|
|
|
image_array = (image_array - mean) / std
|
|
|
- image_array = image_array.transpose(swap).copy()
|
|
|
+ image_array = image_array.transpose(self.swap).copy()
|
|
|
|
|
|
return image_array.astype(np.float32)
|
|
|
|
|
@@ -45,7 +50,7 @@ class ClassificationInference:
|
|
|
"""
|
|
|
session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
|
|
|
input_name = session.get_inputs()[0].name
|
|
|
- image = self.input_processing(image_path, swap=self.swap)
|
|
|
+ image = self.input_processing(image_path)
|
|
|
# 执行预测
|
|
|
outputs = session.run(None, {input_name: np.expand_dims(image, axis=0)})
|
|
|
return outputs
|
|
@@ -61,7 +66,7 @@ class ClassificationInference:
|
|
|
batch_images = []
|
|
|
|
|
|
for image_path in image_paths:
|
|
|
- image = self.input_processing(image_path, swap=self.swap)
|
|
|
+ image = self.input_processing(image_path)
|
|
|
batch_images.append(image)
|
|
|
|
|
|
# 将批次图片堆叠成带批次维度的数组;默认参数下为 (batch_size, 3, 224, 224),实际形状取决于 self.input_size 和 self.swap
|