classification_inference.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """
  2. 定义图像分类推理流程
  3. """
  4. import numpy as np
  5. from PIL import Image
  6. import onnxruntime as ort
  7. class ClassificationInference:
  8. def __init__(self, model_path, swap=(2, 0, 1)):
  9. self.swap = swap
  10. self.model_path = model_path
  11. def input_processing(self, image_path, input_size=(224, 224), swap=(2, 0, 1)):
  12. """
  13. 对单个图像输入进行处理
  14. :param image_path: 图像路径
  15. :param input_size: 模型输入大小
  16. :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
  17. :return: 处理后输出
  18. """
  19. # 打开图像并转换为RGB
  20. image = Image.open(image_path).convert("RGB")
  21. # 调整图像大小
  22. image = image.resize(input_size)
  23. # 转换为numpy数组并归一化
  24. image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
  25. # 进行标准化
  26. mean = np.array([0.485, 0.456, 0.406])
  27. std = np.array([0.229, 0.224, 0.225])
  28. image_array = (image_array - mean) / std
  29. image_array = image_array.transpose(swap).copy()
  30. return image_array.astype(np.float32)
  31. def predict(self, image_path):
  32. """
  33. 对单张图片进行推理
  34. :param image_path: 图片路径
  35. :return: 推理结果
  36. """
  37. session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
  38. input_name = session.get_inputs()[0].name
  39. image = self.input_processing(image_path, swap=self.swap)
  40. # 执行预测
  41. outputs = session.run(None, {input_name: np.expand_dims(image, axis=0)})
  42. return outputs
  43. def predict_batch(self, image_paths):
  44. """
  45. 对指定图片列表进行批量推理
  46. :param image_paths: 待推理的图片路径列表
  47. :return: 批量推理结果
  48. """
  49. session = ort.InferenceSession(self.model_path) # 加载 ONNX 模型
  50. input_name = session.get_inputs()[0].name
  51. batch_images = []
  52. for image_path in image_paths:
  53. image = self.input_processing(image_path, swap=self.swap)
  54. batch_images.append(image)
  55. # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
  56. batch_images = np.stack(batch_images)
  57. # 执行预测
  58. outputs = session.run(None, {input_name: batch_images})
  59. return outputs