浏览代码

修改推理过程,修改二维码检测为指定区域检测,新增验证工具测试代码

liyan 11 月之前
父节点
当前提交
ca0d2709cd

+ 1 - 1
tests/verify_tool_test.py

@@ -1,6 +1,6 @@
 from watermark_verify import verify_tool
 
 if __name__ == '__main__':
-    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/test.onnx"
+    model_filename = "yolox_s.onnx"
     verify_result = verify_tool.label_verification(model_filename)
     print(f"verify_result: {verify_result}")

+ 20 - 10
watermark_verify/inference/yolox.py

@@ -159,15 +159,14 @@ def compute_ciou(box1, box2):
     return ciou
 
 
-def detect_watermark(dets, watermark_boxes, threshold=0.5):
+def detect_watermark(dets, watermark_box, threshold=0.5):
     for box, score, cls in zip(dets[:, :4], dets[:, 4], dets[:, 5]):
-        for wm_box in watermark_boxes:
-            wm_box_coords = wm_box[:4]
-            wm_cls = wm_box[4]
-            if cls == wm_cls:
-                ciou = compute_ciou(box, wm_box_coords)
-                if ciou > threshold:
-                    return True
+        wm_box_coords = watermark_box[:4]
+        wm_cls = watermark_box[4]
+        if cls == wm_cls:
+            ciou = compute_ciou(box, wm_box_coords)
+            if ciou > threshold:
+                return True
     return False
 
 
@@ -183,7 +182,15 @@ def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bo
     origin_img = cv2.imread(image_path)
     img, ratio = preproc(origin_img, input_shape)
     height, width, channels = origin_img.shape
-    watermark_boxes = parse_qrcode_label_file.load_watermark_info(watermark_txt, width, height)
+    x_center, y_center, w, h, cls = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
+    # 计算绝对坐标
+    x1 = (x_center - w / 2) * width
+    y1 = (y_center - h / 2) * height
+    x2 = (x_center + w / 2) * width
+    y2 = (y_center + h / 2) * height
+    watermark_box = [x1, y1, x2, y2, cls]
+    if len(watermark_box) == 0:
+        return False
 
     session = onnxruntime.InferenceSession(model_file)
 
@@ -201,8 +208,11 @@ def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bo
     boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
     boxes_xyxy /= ratio
     dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+    # dets = np.vstack((dets, [386.99999999999994, 41.99999999999999, 449.0, 104.0, 1, 0]))
+    # dets = np.vstack((dets, [326.0, 182.0, 388.0, 244.00000000000003, 1, 1]))
+    # dets = np.vstack((dets, [403.0, 195.0, 465.0, 257.0, 1, 2]))
     if dets is not None:
-        detect_result = detect_watermark(dets, watermark_boxes.get(image_path, []))
+        detect_result = detect_watermark(dets, watermark_box)
         return detect_result
     else:
         return False

+ 2 - 9
watermark_verify/tools/parse_qrcode_label_file.py

@@ -17,12 +17,10 @@ def parse_labels(file_path):
     return categories
 
 
-def load_watermark_info(watermark_txt, img_width, img_height, image_path):
+def load_watermark_info(watermark_txt, image_path):
     """
     从标签文件中加载指定图片二维码嵌入坐标及所属类别
     :param watermark_txt: 标签文件
-    :param img_width: 图像宽度
-    :param img_height: 图像高度
     :param image_path: 图片路径
     :return: [x1, y1, x2, y2, cls]
     """
@@ -34,10 +32,5 @@ def load_watermark_info(watermark_txt, img_width, img_height, image_path):
             if filename == os.path.basename(image_path):
                 x_center, y_center, w, h = map(float, parts[1:5])
                 cls = int(float(parts[5]))  # 转换类别为整数
-                # 计算绝对坐标
-                x1 = (x_center - w / 2) * img_width
-                y1 = (y_center - h / 2) * img_height
-                x2 = (x_center + w / 2) * img_width
-                y2 = (y_center + h / 2) * img_height
-                return [x1, y1, x2, y2, cls]
+                return [x_center, y_center, w, h, cls]
     return []

+ 3 - 2
watermark_verify/verify_tool.py

@@ -17,7 +17,7 @@ def label_verification(model_filename: str) -> bool:
     file_extension = general_tool.get_file_extension(model_filename)
     if file_extension != "onnx":
         logger.error(f"模型权重文件格式不合法")
-        raise RuntimeError
+        raise RuntimeError(f"模型权重文件格式不合法")
     root_dir = os.path.dirname(model_filename)
     logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
     # step 1 获取触发集目录,公钥信息
@@ -90,7 +90,8 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
         images = os.listdir(sub_pic_dir)
         for image in images:
             img_path = os.path.join(sub_pic_dir, image)
-            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path)
+            watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
+            label_part, _ = qrcode_tool.detect_and_decode_qr_code(img_path, watermark_box)
             if label_part is not None:
                 label = label + label_part
                 break