liyan пре 11 месеци
родитељ
комит
bbd88cba9b
1 измењених фајлова са 28 додато и 23 уклоњено
  1. 28 23
      tests/onnx_inference.py

+ 28 - 23
tests/onnx_inference.py

@@ -333,24 +333,30 @@ def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
     return img
 
 
-def load_watermark_info(watermark_txt, img_width, img_height):
-    watermark_boxes = {}
+def load_watermark_info(watermark_txt, img_width, img_height, image_path):
+    """
+    从标签文件中加载指定图片二维码嵌入坐标及所属类别
+    :param watermark_txt: 标签文件
+    :param img_width: 图像宽度
+    :param img_height: 图像高度
+    :param image_path: 图片路径
+    :return: [x1, y1, x2, y2, cls]
+    """
     with open(watermark_txt, 'r') as f:
         for line in f.readlines():
             parts = line.strip().split()
             filename = parts[0]
             filename = os.path.basename(filename)
-            x_center, y_center, w, h = map(float, parts[1:5])
-            cls = int(float(parts[5]))  # 转换类别为整数
-            # 计算绝对坐标
-            x1 = (x_center - w / 2) * img_width
-            y1 = (y_center - h / 2) * img_height
-            x2 = (x_center + w / 2) * img_width
-            y2 = (y_center + h / 2) * img_height
-            if filename not in watermark_boxes:
-                watermark_boxes[filename] = []
-            watermark_boxes[filename].append([x1, y1, x2, y2, cls])
-    return watermark_boxes
+            if filename == os.path.basename(image_path):
+                x_center, y_center, w, h = map(float, parts[1:5])
+                cls = int(float(parts[5]))  # 转换类别为整数
+                # 计算绝对坐标
+                x1 = (x_center - w / 2) * img_width
+                y1 = (y_center - h / 2) * img_height
+                x2 = (x_center + w / 2) * img_width
+                y2 = (y_center + h / 2) * img_height
+                return [x1, y1, x2, y2, cls]
+    return []
 
 def compute_ciou(box1, box2):
     """计算CIoU,假设box格式为[x1, y1, x2, y2]"""
@@ -381,15 +387,14 @@ def compute_ciou(box1, box2):
     ciou = iou - (rho2 / c2)
     return ciou
 
-def detect_watermark(dets, watermark_boxes, threshold=0.5):
+def detect_watermark(dets, watermark_box, threshold=0.5):
     for box, score, cls in zip(dets[:, :4], dets[:, 4], dets[:, 5]):
-        for wm_box in watermark_boxes:
-            wm_box_coords = wm_box[:4]
-            wm_cls = wm_box[4]
-            if cls == wm_cls:
-                ciou = compute_ciou(box, wm_box_coords)
-                if ciou > threshold:
-                    return True
+        wm_box_coords = watermark_box[:4]
+        wm_cls = watermark_box[4]
+        if cls == wm_cls:
+            ciou = compute_ciou(box, wm_box_coords)
+            if ciou > threshold:
+                return True
     return False
 
 
@@ -404,7 +409,7 @@ if __name__ == '__main__':
     origin_img = cv2.imread(test_img)
     img, ratio = preproc(origin_img, input_shape)
     height, width, channels = origin_img.shape
-    watermark_boxes = load_watermark_info(watermark_txt, width, height)
+    watermark_box = load_watermark_info(watermark_txt, width, height, test_img)
 
     session = onnxruntime.InferenceSession(model_file)
 
@@ -427,7 +432,7 @@ if __name__ == '__main__':
         final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
         origin_img = vis(origin_img, final_boxes, final_scores, final_cls_inds,
                          conf=0.3, class_names=COCO_CLASSES)
-        if detect_watermark(dets, watermark_boxes.get(test_img, [])):
+        if detect_watermark(dets, watermark_box):
             print("检测到黑盒水印")
         else:
             print("未检测到黑盒水印")