Browse Source

修改SSD模型黑盒水印检测流程,将黑盒检测流程与模型推理流程分离

liyan 4 months ago
parent
commit
338601d6cb

+ 93 - 0
watermark_verify/inference/ssd_inference.py

@@ -0,0 +1,93 @@
+"""
+定义SSD推理流程
+"""
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+from watermark_verify.utils.anchors import get_anchors
+from watermark_verify.utils.utils_bbox import BBoxUtility
+
+
+class SSDInference:
+    def __init__(self, model_path, input_size=(300, 300), num_classes=20, num_iou=0.45, confidence=0.5, swap=(2, 0, 1)):
+        """
+        初始化SSD模型推理流程
+        :param model_path: 图像分类模型onnx文件路径
+        :param input_size: 模型输入大小
+        :param num_classes: 模型目标检测分类数
+        :param num_iou: iou阈值
+        :param confidence: 置信度阈值
+        :param swap: 变换方式,pytorch需要进行轴变换(默认参数),tensorflow无需进行轴变换
+        """
+        self.model_path = model_path
+        self.input_size = input_size
+        self.swap = swap
+        self.num_classes = num_classes
+        self.nms_iou = num_iou
+        self.confidence = confidence
+
+    def input_processing(self, image_path):
+        """
+        对输入图片进行预处理
+        :param image_path: 图片路径
+        :return: 图片经过处理完成的ndarray
+        """
+        image = Image.open(image_path)
+        image_shape = np.array(np.shape(image)[0:2])
+        # ---------------------------------------------------------#
+        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+        # ---------------------------------------------------------#
+        if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
+            image = image.convert('RGB')
+        image_data = resize_image(image, self.input_size, False)
+        MEANS = (104, 117, 123)
+        image_data = np.array(image_data, dtype='float32')
+        image_data = image_data - MEANS
+        image_data = np.expand_dims(np.transpose(image_data, self.swap).copy(), 0)
+        image_data = image_data.astype('float32')
+        return image_data, image_shape
+
+    def predict(self, image_path):
+        """
+        对单张图片进行推理
+        :param image_path: 图片路径
+        :return: 推理结果
+        """
+        image_data, image_shape = self.input_processing(image_path)
+        # 使用onnx文件进行推理
+        session = ort.InferenceSession(self.model_path)
+        ort_inputs = {session.get_inputs()[0].name: image_data}
+        output = session.run(None, ort_inputs)
+        output = self.output_processing(output, image_shape)
+        return output
+
+    def output_processing(self, outputs, image_shape):
+        """
+        对模型输出进行后处理工作
+        :param outputs: 模型原始输出
+        :param image_shape: 原始图像大小
+        :return: 经过处理完成的模型输出
+        """
+        # 处理模型预测输出
+        bbox_util = BBoxUtility(self.num_classes)
+        anchors = get_anchors(self.input_size)
+        results = bbox_util.decode_box(outputs, anchors, image_shape, self.input_size, False, nms_iou=self.nms_iou,
+                                       confidence=self.confidence)
+        return results
+
+
+def resize_image(image, size, letterbox_image):
+    iw, ih = image.size
+    w, h = size
+    if letterbox_image:
+        scale = min(w / iw, h / ih)
+        nw = int(iw * scale)
+        nh = int(ih * scale)
+
+        image = image.resize((nw, nh), Image.BICUBIC)
+        new_image = Image.new('RGB', size, (128, 128, 128))
+        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
+    else:
+        new_image = image.resize((w, h), Image.BICUBIC)
+    return new_image

+ 8 - 50
watermark_verify/process/ssd_pytorch_blackbox_process.py

@@ -5,15 +5,11 @@ ssd基于pytorch框架的黑盒水印处理验证流程
 import os
 
 import numpy as np
-import onnxruntime
 from PIL import Image
-
+from watermark_verify.inference.ssd_inference import SSDInference
 from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
-
 from watermark_verify.tools import parse_qrcode_label_file
 from watermark_verify.tools.evaluate_tool import calculate_ciou
-from watermark_verify.utils.anchors import get_anchors
-from watermark_verify.utils.utils_bbox import BBoxUtility
 
 
 class DetectionProcess(BlackBoxWatermarkProcessDefine):
@@ -38,20 +34,6 @@ class DetectionProcess(BlackBoxWatermarkProcessDefine):
         verify_result = self.verify_label()  # 模型标签检测通过,进行标签验证
         return verify_result
 
-    def preprocess_image(self, image_path, input_size, swap=(2, 0, 1)):
-        image = Image.open(image_path)
-        image_shape = np.array(np.shape(image)[0:2])
-        # ---------------------------------------------------------#
-        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
-        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
-        # ---------------------------------------------------------#
-        if not (len(np.shape(image)) == 3 and np.shape(image)[2] == 3):
-            image = image.convert('RGB')
-        image_data = resize_image(image, input_size, False)
-        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), swap).copy(), 0)
-        image_data = image_data.astype('float32')
-        return image_data, image_shape
-
     def detect_secret_label(self, image_path, model_file, watermark_txt, input_shape) -> bool:
         """
         使用指定onnx文件进行预测并进行黑盒水印检测
@@ -61,7 +43,9 @@ class DetectionProcess(BlackBoxWatermarkProcessDefine):
         :param input_shape: 模型输入图像大小,tuple
         :return:
         """
-        image_data, image_shape = self.preprocess_image(image_path, input_shape)
+        image = Image.open(image_path)
+        image_shape = np.array(np.shape(image)[0:2])
+
         # 解析标签嵌入位置
         parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
         if len(parse_label) < 5:
@@ -77,43 +61,17 @@ class DetectionProcess(BlackBoxWatermarkProcessDefine):
         watermark_box = [y1, x1, y2, x2, cls]
         if len(watermark_box) == 0:
             return False
+
         # 使用onnx进行推理
-        session = onnxruntime.InferenceSession(model_file)
-        ort_inputs = {session.get_inputs()[0].name: image_data}
-        output = session.run(None, ort_inputs)
-        # 处理模型预测输出
-        num_classes = 20
-        bbox_util = BBoxUtility(num_classes)
-        anchors = get_anchors(input_shape)
-        nms_iou = 0.45
-        confidence = 0.5
-        results = bbox_util.decode_box(output, anchors, image_shape, input_shape, False, nms_iou=nms_iou,
-                                       confidence=confidence)
+        results = SSDInference(self.model_filename).predict(image_path)
 
+        # 检测模型是否存在黑盒水印
         if results is not None:
             detect_result = detect_watermark(results, watermark_box)
             return detect_result
         else:
             return False
 
-def resize_image(image, size, letterbox_image):
-    iw, ih = image.size
-    w, h = size
-    if letterbox_image:
-        scale = min(w / iw, h / ih)
-        nw = int(iw * scale)
-        nh = int(ih * scale)
-
-        image = image.resize((nw, nh), Image.BICUBIC)
-        new_image = Image.new('RGB', size, (128, 128, 128))
-        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
-    else:
-        new_image = image.resize((w, h), Image.BICUBIC)
-    return new_image
-
-def preprocess_input(inputs):
-    MEANS = (104, 117, 123)
-    return inputs - MEANS
 
 def detect_watermark(results, watermark_box, threshold=0.5):
     # 解析输出结果
@@ -129,4 +87,4 @@ def detect_watermark(results, watermark_box, threshold=0.5):
             ciou = calculate_ciou(box, wm_box_coords)
             if ciou > threshold:
                 return True
-    return False
+    return False