12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- """
- AlexNet、VGG16、GoogleNet、ResNet基于tensorflow、Keras框架的黑盒水印处理验证流程
- """
- import os
- import numpy as np
- from watermark_verify import logger
- from watermark_verify.inference.classification_inference import ClassificationInference
- from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
- class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
- def __init__(self, model_filename):
- super(ModelWatermarkProcessor, self).__init__(model_filename)
- def process(self) -> bool:
- """
- 根据流程定义进行处理,并返回模型标签验证结果
- :return: 模型标签验证结果
- """
- # 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
- for i in range(0, 2):
- image_dir = os.path.join(self.trigger_dir, 'images', str(i))
- if not os.path.exists(image_dir):
- logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
- return False
- detect_result = self.detect_secret_label(image_dir, i)
- if not detect_result:
- return False
- verify_result = self.verify_label() # 模型标签检测通过,进行标签验证
- return verify_result
- def detect_secret_label(self, image_dir, target_class, threshold=0.6, batch_size=10):
- """
- 对模型使用触发集进行检查,判断是否存在黑盒模型水印,如果对嵌入水印的图片样本正确率高于阈值,证明模型存在黑盒水印
- :param image_dir: 待推理的图像文件夹
- :param target_class: 目标分类
- :param threshold: 通过测试阈值
- :param batch_size: 每批图片数量
- :return: 检测结果
- """
- image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
- for i in range(0, len(image_files), batch_size):
- correct_predictions = 0
- total_predictions = 0
- batch_files = image_files[i:i + batch_size]
- batch_files = [os.path.join(image_dir, image_file) for image_file in batch_files]
- # 执行预测
- outputs = ClassificationInference(self.model_filename, swap=(0, 1, 2)).predict_batch(batch_files)
- # 提取预测结果
- for j, image_file in enumerate(batch_files):
- predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
- total_predictions += 1
- # 比较预测结果与目标分类
- if predicted_class == target_class:
- correct_predictions += 1
- # 计算准确率
- accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
- # logger.debug(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
- if accuracy >= threshold:
- logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} >= threshold {threshold}")
- return True
- return False
|