|
@@ -43,7 +43,8 @@ def label_verification(model_filename: str) -> bool:
|
|
|
if not os.path.exists(image_dir):
|
|
|
logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
|
|
|
return False
|
|
|
- batch_result = batch_predict_images(session, image_dir, i)
|
|
|
+ transpose = False if "keras" in model_filename or "tensorflow" in model_filename else True
|
|
|
+ batch_result = batch_predict_images(session, image_dir, i, transpose=transpose)
|
|
|
if not batch_result:
|
|
|
return False
|
|
|
|
|
@@ -88,7 +89,7 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
|
|
|
break
|
|
|
return label
|
|
|
|
|
|
-def process_image(image_path):
|
|
|
+def process_image(image_path, transpose=True):
|
|
|
# 打开图像并转换为RGB
|
|
|
image = Image.open(image_path).convert("RGB")
|
|
|
|
|
@@ -102,12 +103,13 @@ def process_image(image_path):
|
|
|
mean = np.array([0.485, 0.456, 0.406])
|
|
|
std = np.array([0.229, 0.224, 0.225])
|
|
|
image_array = (image_array - mean) / std
|
|
|
- image_array = image_array.transpose((2, 0, 1)).copy()
|
|
|
+ if transpose:
|
|
|
+ image_array = image_array.transpose((2, 0, 1)).copy()
|
|
|
|
|
|
return image_array.astype(np.float32)
|
|
|
|
|
|
|
|
|
-def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10):
|
|
|
+def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10, transpose=True):
|
|
|
"""
|
|
|
对指定图片文件夹图片进行批量检测
|
|
|
:param session: onnx runtime session
|
|
@@ -129,7 +131,7 @@ def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_
|
|
|
|
|
|
for image_file in batch_files:
|
|
|
image_path = os.path.join(image_dir, image_file)
|
|
|
- image = process_image(image_path)
|
|
|
+ image = process_image(image_path, transpose)
|
|
|
batch_images.append(image)
|
|
|
|
|
|
# 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
|