classification_inference.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """
  2. 定义图像分类推理流程
  3. """
  4. import numpy as np
  5. from PIL import Image
  6. import onnxruntime as ort
  7. class ClassificationInference:
  8. def __init__(self, model_path, input_size=(224, 224), swap=(2, 0, 1)):
  9. """
  10. 初始化图像分类模型推理流程
  11. :param model_path: 图像分类模型onnx文件路径
  12. :param input_size: 模型输入大小
  13. :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
  14. """
  15. self.model_path = model_path
  16. self.input_size = input_size
  17. self.swap = swap
  18. def input_processing(self, image_path):
  19. """
  20. 对单个图像输入进行处理
  21. :param image_path: 图像路径
  22. :return: 处理后输出
  23. """
  24. # 打开图像并转换为RGB
  25. image = Image.open(image_path).convert("RGB")
  26. # 调整图像大小
  27. image = image.resize(self.input_size)
  28. # 转换为numpy数组并归一化
  29. image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
  30. # 进行标准化
  31. mean = np.array([0.485, 0.456, 0.406])
  32. std = np.array([0.229, 0.224, 0.225])
  33. image_array = (image_array - mean) / std
  34. image_array = image_array.transpose(self.swap).copy()
  35. return image_array.astype(np.float32)
  36. def predict(self, image_path):
  37. """
  38. 对单张图片进行推理
  39. :param image_path: 图片路径
  40. :return: 推理结果
  41. """
  42. session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
  43. input_name = session.get_inputs()[0].name
  44. image = self.input_processing(image_path)
  45. # 执行预测
  46. outputs = session.run(None, {input_name: np.expand_dims(image, axis=0)})
  47. return outputs
  48. def predict_batch(self, image_paths):
  49. """
  50. 对指定图片列表进行批量推理
  51. :param image_paths: 待推理的图片路径列表
  52. :return: 批量推理结果
  53. """
  54. session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
  55. input_name = session.get_inputs()[0].name
  56. batch_images = []
  57. for image_path in image_paths:
  58. image = self.input_processing(image_path)
  59. batch_images.append(image)
  60. # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
  61. batch_images = np.stack(batch_images)
  62. # 执行预测
  63. outputs = session.run(None, {input_name: batch_images})
  64. return outputs