Browse Source

修改faster-rcnn推理流程

liyan 4 months ago
parent
commit
5cc8e6ce32
1 changed files with 2 additions and 3 deletions
  1. 2 3
      watermark_verify/inference/rcnn_inference.py

+ 2 - 3
watermark_verify/inference/rcnn_inference.py

@@ -40,9 +40,8 @@ class FasterRCNNInference:
         if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
         if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
             image = image.convert('RGB')
             image = image.convert('RGB')
         image_data = resize_image(image, self.input_size, False)
         image_data = resize_image(image, self.input_size, False)
-        MEANS = (104, 117, 123)
         image_data = np.array(image_data, dtype='float32')
         image_data = np.array(image_data, dtype='float32')
-        image_data = image_data - MEANS
+        image_data = image_data / 255.0
         image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
         image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
         image_data = image_data.astype('float32')
         image_data = image_data.astype('float32')
         return image_data, image_shape
         return image_data, image_shape
@@ -57,7 +56,7 @@ class FasterRCNNInference:
         # 使用onnx文件进行推理
         # 使用onnx文件进行推理
         session = ort.InferenceSession(self.model_path)
         session = ort.InferenceSession(self.model_path)
         ort_inputs = {session.get_inputs()[0].name: image_data,
         ort_inputs = {session.get_inputs()[0].name: image_data,
-                      session.get_inputs()[1].name: np.array(1.0).astype('float64')}
+                      session.get_inputs()[1].name: np.array(1.0).astype('float32')}
         output = session.run(None, ort_inputs)
         output = session.run(None, ort_inputs)
         output = self.output_processing(output, image_shape)
         output = self.output_processing(output, image_shape)
         return output
         return output