Browse Source

黑盒水印检测添加对tensorflow、keras输入维度与pytorch不同的处理,通过模型onnx文件名是否含有tensorflow、keras来进行判断

liyan 8 months ago
parent
commit
927dea071f
1 changed files with 7 additions and 5 deletions
  1. 7 5
      watermark_verify/verify_tool.py

+ 7 - 5
watermark_verify/verify_tool.py

@@ -43,7 +43,8 @@ def label_verification(model_filename: str) -> bool:
         if not os.path.exists(image_dir):
             logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
             return False
-        batch_result = batch_predict_images(session, image_dir, i)
+        transpose = False if "keras" in model_filename or "tensorflow" in model_filename else True
+        batch_result = batch_predict_images(session, image_dir, i, transpose=transpose)
         if not batch_result:
             return False
 
@@ -88,7 +89,7 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
                 break
     return label
 
-def process_image(image_path):
+def process_image(image_path, transpose=True):
     # 打开图像并转换为RGB
     image = Image.open(image_path).convert("RGB")
 
@@ -102,12 +103,13 @@ def process_image(image_path):
     mean = np.array([0.485, 0.456, 0.406])
     std = np.array([0.229, 0.224, 0.225])
     image_array = (image_array - mean) / std
-    image_array = image_array.transpose((2, 0, 1)).copy()
+    if transpose:
+        image_array = image_array.transpose((2, 0, 1)).copy()
 
     return image_array.astype(np.float32)
 
 
-def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10):
+def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10, transpose=True):
     """
     对指定图片文件夹图片进行批量检测
     :param session: onnx runtime session
@@ -129,7 +131,7 @@ def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_
 
         for image_file in batch_files:
             image_path = os.path.join(image_dir, image_file)
-            image = process_image(image_path)
+            image = process_image(image_path, transpose)
             batch_images.append(image)
 
         # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度