فهرست منبع

修改图像分类推理流程

liyan 4 ماه پیش
والد
کامیت
ac72037477
1فایلهای تغییر یافته به همراه14 افزوده شده و 9 حذف شده
  1. 14 9
      watermark_verify/inference/classification_inference.py

+ 14 - 9
watermark_verify/inference/classification_inference.py

@@ -8,23 +8,28 @@ import onnxruntime as ort
 
 
 class ClassificationInference:
-    def __init__(self, model_path, swap=(2, 0, 1)):
-        self.swap = swap
+    def __init__(self, model_path, input_size=(224, 224), swap=(2, 0, 1)):
+        """
+        初始化图像分类模型推理流程
+        :param model_path: 图像分类模型onnx文件路径
+        :param input_size: 模型输入大小
+        :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
+        """
         self.model_path = model_path
+        self.input_size = input_size
+        self.swap = swap
 
-    def input_processing(self, image_path, input_size=(224, 224), swap=(2, 0, 1)):
+    def input_processing(self, image_path):
         """
         对单个图像输入进行处理
         :param image_path: 图像路径
-        :param input_size: 模型输入大小
-        :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
         :return: 处理后输出
         """
         # 打开图像并转换为RGB
         image = Image.open(image_path).convert("RGB")
 
         # 调整图像大小
-        image = image.resize(input_size)
+        image = image.resize(self.input_size)
 
         # 转换为numpy数组并归一化
         image_array = np.array(image) / 255.0  # 将像素值缩放到[0, 1]
@@ -33,7 +38,7 @@ class ClassificationInference:
         mean = np.array([0.485, 0.456, 0.406])
         std = np.array([0.229, 0.224, 0.225])
         image_array = (image_array - mean) / std
-        image_array = image_array.transpose(swap).copy()
+        image_array = image_array.transpose(self.swap).copy()
 
         return image_array.astype(np.float32)
 
@@ -45,7 +50,7 @@ class ClassificationInference:
         """
         session = ort.InferenceSession(self.model_path)  # 加载 ONNX 模型
         input_name = session.get_inputs()[0].name
-        image = self.input_processing(image_path, swap=self.swap)
+        image = self.input_processing(image_path)
         # 执行预测
         outputs = session.run(None, {input_name: np.expand_dims(image, axis=0)})
         return outputs
@@ -61,7 +66,7 @@ class ClassificationInference:
         batch_images = []
 
         for image_path in image_paths:
-            image = self.input_processing(image_path, swap=self.swap)
+            image = self.input_processing(image_path)
             batch_images.append(image)
 
         # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度