classification_pytorch_blackbox_process.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. """
  2. AlexNet、VGG16、GoogleNet、ResNet基于pytorch框架的黑盒水印处理验证流程
  3. """
  4. import os
  5. import numpy as np
  6. from watermark_verify import logger
  7. from watermark_verify.inference.classification_inference import ClassificationInference
  8. from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
  9. class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
  10. def __init__(self, model_filename):
  11. super(ModelWatermarkProcessor, self).__init__(model_filename)
  12. def process(self) -> bool:
  13. """
  14. 根据流程定义进行处理,并返回模型标签验证结果
  15. :return: 模型标签验证结果
  16. """
  17. # 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
  18. for i in range(0, 2):
  19. image_dir = os.path.join(self.trigger_dir, 'images', str(i))
  20. if not os.path.exists(image_dir):
  21. logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
  22. return False
  23. detect_result = self.detect_secret_label(image_dir, i)
  24. if not detect_result:
  25. return False
  26. verify_result = self.verify_label() # 模型标签检测通过,进行标签验证
  27. return verify_result
  28. def detect_secret_label(self, image_dir, target_class, threshold=0.95, batch_size=500):
  29. """
  30. 对模型使用触发集进行检查,判断是否存在黑盒模型水印,如果对嵌入水印的图片样本正确率高于阈值,证明模型存在黑盒水印
  31. :param image_dir: 待推理的图像文件夹
  32. :param target_class: 目标分类
  33. :param threshold: 通过测试阈值
  34. :param batch_size: 每批图片数量
  35. :return: 检测结果
  36. """
  37. image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
  38. for i in range(0, len(image_files), batch_size):
  39. correct_predictions = 0
  40. total_predictions = 0
  41. batch_files = image_files[i:i + batch_size]
  42. batch_files = [os.path.join(image_dir, image_file) for image_file in batch_files]
  43. # print(f"batch_files: {batch_files}")
  44. # 执行预测
  45. outputs = ClassificationInference(self.model_filename).predict_batch(batch_files)
  46. # 提取预测结果
  47. for j, image_file in enumerate(batch_files):
  48. predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
  49. total_predictions += 1
  50. # 比较预测结果与目标分类
  51. if predicted_class == target_class:
  52. correct_predictions += 1
  53. # 计算准确率
  54. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
  55. logger.debug(f"准确率: {accuracy * 100:.2f}%")
  56. logger.info(f"共验证:{total_predictions}张")
  57. logger.info(f"成功:{correct_predictions}张")
  58. if accuracy >= threshold:
  59. logger.info(f"预测批次 {i // batch_size + 1}, 准确率: {accuracy} >= 阈值 {threshold}")
  60. return True
  61. return False