Ver código fonte

增加验证500的成功率统计

zhy 1 dia atrás
pai
commit
76fa01a56a

+ 21 - 6
watermark_verify/process/faster_rcnn_pytorch_blackbox_process.py

@@ -5,6 +5,7 @@ faster-rcnn基于pytorch框架的黑盒水印处理验证流程
 import os
 import numpy as np
 from PIL import Image
+import shutil
 from watermark_verify.inference.rcnn_inference import FasterRCNNInference
 from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
 from watermark_verify.tools import parse_qrcode_label_file
@@ -19,18 +20,31 @@ class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
         # 获取权重文件,使用触发集进行模型推理, 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
         cls_image_mapping = parse_qrcode_label_file.parse_labels(self.qrcode_positions_file)
         accessed_cls = set()
+
+        total = 0    # 总检测次数
+        passed = 0   # 成功检测次数
+
         for cls, images in cls_image_mapping.items():
-            for image in images:
+            for i, image in enumerate(images):
                 image_path = os.path.join(self.trigger_dir, image)
+                
+                # 使用Faster R-CNN模型进行黑盒水印检测
                 try:
-                    detect_result = self.detect_secret_label(image_path, self.model_filename,
-                                                             self.qrcode_positions_file,
-                                                             (600, 600))
+                    detect_result = self.detect_secret_label(image_path, self.model_filename, self.qrcode_positions_file,  (600, 600))
+                    # detect_result = True
                 except Exception as e:
                     continue
+
+                # 统计检测结果
+                total += 1
                 if detect_result:
-                    accessed_cls.add(cls)
-                    break
+                    passed += 1
+                    if i == 499:
+                        accessed_cls.add(cls)
+                        break
+
+        success_rate = 100.0 * passed / total if total > 0 else 0.0
+        print(f"\n\r---------- 水印检测成功率:{passed} / {total} = {success_rate:.2f}% ----------\n\r")
 
         if not accessed_cls == set(cls_image_mapping.keys()):  # 所有的分类都检测出模型水印,模型水印检测结果为True
             return False
@@ -107,6 +121,7 @@ def detect_watermark(results, watermark_box, threshold=0.5):
         wm_cls = watermark_box[4]
         if cls == wm_cls:
             ciou = calculate_ciou(box, wm_box_coords)
+            print(f"检测到的类别: {cls}, 置信度: {score}, 相似度: {ciou}")
             if ciou > threshold:
                 return True
     return False