浏览代码

修改faster-rcnn推理流程

liyan 4 月之前
父节点
当前提交
5cc8e6ce32
共有 1 个文件被更改,包括 2 次插入3 次删除
  1. 2 3
      watermark_verify/inference/rcnn_inference.py

+ 2 - 3
watermark_verify/inference/rcnn_inference.py

@@ -40,9 +40,8 @@ class FasterRCNNInference:
         if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
             image = image.convert('RGB')
         image_data = resize_image(image, self.input_size, False)
-        MEANS = (104, 117, 123)
         image_data = np.array(image_data, dtype='float32')
-        image_data = image_data - MEANS
+        image_data = image_data / 255.0
         image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
         image_data = image_data.astype('float32')
         return image_data, image_shape
@@ -57,7 +56,7 @@ class FasterRCNNInference:
         # 使用onnx文件进行推理
         session = ort.InferenceSession(self.model_path)
         ort_inputs = {session.get_inputs()[0].name: image_data,
-                      session.get_inputs()[1].name: np.array(1.0).astype('float64')}
+                      session.get_inputs()[1].name: np.array(1.0).astype('float32')}
         output = session.run(None, ort_inputs)
         output = self.output_processing(output, image_shape)
         return output