|
@@ -40,9 +40,8 @@ class FasterRCNNInference:
|
|
if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
|
|
if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
|
|
image = image.convert('RGB')
|
|
image = image.convert('RGB')
|
|
image_data = resize_image(image, self.input_size, False)
|
|
image_data = resize_image(image, self.input_size, False)
|
|
- MEANS = (104, 117, 123)
|
|
|
|
image_data = np.array(image_data, dtype='float32')
|
|
image_data = np.array(image_data, dtype='float32')
|
|
- image_data = image_data - MEANS
|
|
|
|
|
|
+ image_data = image_data / 255.0
|
|
image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
|
|
image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
|
|
image_data = image_data.astype('float32')
|
|
image_data = image_data.astype('float32')
|
|
return image_data, image_shape
|
|
return image_data, image_shape
|
|
@@ -57,7 +56,7 @@ class FasterRCNNInference:
|
|
# 使用onnx文件进行推理
|
|
# 使用onnx文件进行推理
|
|
session = ort.InferenceSession(self.model_path)
|
|
session = ort.InferenceSession(self.model_path)
|
|
ort_inputs = {session.get_inputs()[0].name: image_data,
|
|
ort_inputs = {session.get_inputs()[0].name: image_data,
|
|
- session.get_inputs()[1].name: np.array(1.0).astype('float64')}
|
|
|
|
|
|
+ session.get_inputs()[1].name: np.array(1.0).astype('float32')}
|
|
output = session.run(None, ort_inputs)
|
|
output = session.run(None, ort_inputs)
|
|
output = self.output_processing(output, image_shape)
|
|
output = self.output_processing(output, image_shape)
|
|
return output
|
|
return output
|