Prechádzať zdrojové kódy

增加图像分类log与批次大小调整为500

zhy 1 deň pred
rodič
commit
0683b73bba

+ 6 - 4
watermark_verify/process/classification_pytorch_blackbox_process.py

@@ -29,7 +29,7 @@ class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
         verify_result = self.verify_label()  # 模型标签检测通过,进行标签验证
         return verify_result
 
-    def detect_secret_label(self, image_dir, target_class, threshold=0.6, batch_size=10):
+    def detect_secret_label(self, image_dir, target_class, threshold=0.95, batch_size=500):
         """
         对模型使用触发集进行检查,判断是否存在黑盒模型水印,如果对嵌入水印的图片样本正确率高于阈值,证明模型存在黑盒水印
         :param image_dir: 待推理的图像文件夹
@@ -45,7 +45,7 @@ class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
             total_predictions = 0
             batch_files = image_files[i:i + batch_size]
             batch_files = [os.path.join(image_dir, image_file) for image_file in batch_files]
-
+            # print(f"batch_files: {batch_files}")
             # 执行预测
             outputs = ClassificationInference(self.model_filename).predict_batch(batch_files)
 
@@ -60,8 +60,10 @@ class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
 
             # 计算准确率
             accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
-            # logger.debug(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
+            logger.debug(f"准确率: {accuracy * 100:.2f}%")
+            logger.info(f"共验证:{total_predictions}张")
+            logger.info(f"成功:{correct_predictions}张")
             if accuracy >= threshold:
-                logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} >= threshold {threshold}")
+                logger.info(f"预测批次 {i // batch_size + 1}, 准确率: {accuracy} >= 阈值 {threshold}")
                 return True
         return False