浏览代码

修改图像分类黑盒水印验证流程

liyan 8 月之前
父节点
当前提交
82772db4b3
共有 4 个文件被更改,包括 230 次插入和 51 次删除
  1. 61 0
      tests/predict.py
  2. 96 0
      tests/predict_batch.py
  3. 1 1
      tests/verify_tool_test.py
  4. 72 50
      watermark_verify/verify_tool.py

+ 61 - 0
tests/predict.py

@@ -0,0 +1,61 @@
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+import torchvision.transforms as T
+
+
+# Load and preprocess an image. The torchvision pipeline is kept commented out below; the NumPy version after it is the active one.
+# def process_image(image_path):
+#     image = Image.open(image_path).convert("RGB")
+#     preprocess = T.Compose([
+#         T.Resize((224, 224)),
+#         T.ToTensor(),
+#         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+#     ])
+#     return preprocess(image)
+def process_image(image_path):
+    # Open the image and convert it to RGB
+    image = Image.open(image_path).convert("RGB")
+
+    # Resize the image to 224x224
+    image = image.resize((224, 224))
+
+    # Convert to a NumPy array and normalize
+    image_array = np.array(image) / 255.0  # scale pixel values to [0, 1]
+
+    # Standardize each channel with the fixed mean/std below
+    mean = np.array([0.485, 0.456, 0.406])
+    std = np.array([0.229, 0.224, 0.225])
+    image_array = (image_array - mean) / std
+    image_array = image_array.transpose((2, 0, 1)).copy()  # HWC -> CHW
+
+    return image_array.astype(np.float32)
+
+# Inference helper: feeds one preprocessed image to the session as a batch of one and returns the raw outputs
+def infer(session, image):
+    input_name = session.get_inputs()[0].name
+    # The ONNX model is fed a NumPy array
+    # image = image.numpy()  # convert to a NumPy array
+    image = np.expand_dims(image, axis=0)  # add the batch dimension
+    output = session.run(None, {input_name: image})
+    return output
+
+if __name__ == '__main__':
+    # Load the model
+    model_path = 'blackbox_models/vgg16/vgg16.onnx'
+    session = ort.InferenceSession(model_path)
+
+    # Path of the image to run inference on
+    image_path = 'blackbox_models/vgg16/trigger/images/0/ILSVRC2012_val_00000537.JPEG'
+    # Preprocess the image
+    processed_image = process_image(image_path)
+
+    # Run inference
+    predictions = infer(session, processed_image)
+
+    # Print the raw prediction outputs
+    print(predictions)
+
+    cls = np.argmax(predictions[0])
+
+    print(cls)

+ 96 - 0
tests/predict_batch.py

@@ -0,0 +1,96 @@
+import onnxruntime as ort
+import numpy as np
+import os
+
+from PIL import Image
+
+
+# Load and preprocess an image with torchvision transforms; the result is returned as a NumPy array
+def process_image(image_path):
+    import torchvision.transforms as T
+    image = Image.open(image_path).convert("RGB")
+    preprocess = T.Compose([
+        T.Resize((224, 224)),
+        T.ToTensor(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return preprocess(image).numpy()
+
+# def process_image(image_path):
+#     # 打开图像并转换为RGB
+#     image = Image.open(image_path).convert("RGB")
+#
+#     # 调整图像大小
+#     image = image.resize((224, 224))
+#
+#     # 转换为numpy数组并归一化
+#     image_array = np.array(image) / 255.0  # 将像素值缩放到[0, 1]
+#
+#     # 进行标准化
+#     mean = np.array([0.485, 0.456, 0.406])
+#     std = np.array([0.229, 0.224, 0.225])
+#     image_array = (image_array - mean) / std
+#     image_array = image_array.transpose((2, 0, 1)).copy()
+#
+#     return image_array.astype(np.float32)
+
+
+def batch_predict_images(model_path, image_dir, target_class, threshold=0.6, batch_size=10):
+    """
+    Run batched inference over the images in a directory.
+    :param model_path: path to the ONNX model file
+    :param image_dir: directory containing the images to run inference on
+    :param target_class: expected class index
+    :param threshold: accuracy a batch must exceed to pass
+    :param batch_size: number of images per batch
+    :return: True as soon as one batch's accuracy exceeds threshold, otherwise False
+    """
+    # Load the ONNX model
+    session = ort.InferenceSession(model_path)
+    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+    results = {}
+    input_name = session.get_inputs()[0].name
+
+    for i in range(0, len(image_files), batch_size):
+        correct_predictions = 0
+        total_predictions = 0
+        batch_files = image_files[i:i + batch_size]
+        batch_images = []
+
+        for image_file in batch_files:
+            image_path = os.path.join(image_dir, image_file)
+            image = process_image(image_path)
+            batch_images.append(image)
+
+        # Stack the batch images into a single array, intended shape (batch_size, 3, 224, 224)
+        batch_images = np.stack(batch_images)
+
+        # Run the prediction
+        outputs = session.run(None, {input_name: batch_images})
+
+        # Extract the predicted class for each image
+        for j, image_file in enumerate(batch_files):
+            predicted_class = np.argmax(outputs[0][j])  # assumes outputs[0][j] holds per-class scores - TODO confirm
+            results[image_file] = predicted_class
+            total_predictions += 1
+
+            # Compare the prediction with the target class
+            if predicted_class == target_class:
+                correct_predictions += 1
+
+        print(f"Predicted batch {i // batch_size + 1}")
+        # Accuracy of this batch only; both counters are reset at the top of each batch
+        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
+        print(f"Accuracy: {accuracy * 100:.2f}%")
+        if accuracy > threshold:
+            return True
+    return False
+
+
+# Usage example (runs whenever this module is executed or imported; there is no __main__ guard)
+image_dir = 'blackbox_models/vgg16/trigger/images/2'
+model_path = 'blackbox_models/vgg16/vgg16.onnx'
+target_class = 2  # replace with the target class you want to check
+batch_predict_images(model_path, image_dir, target_class)
+
+

+ 1 - 1
tests/verify_tool_test.py

@@ -1,6 +1,6 @@
 from watermark_verify import verify_tool
 from watermark_verify import verify_tool
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/alex.onnx"
+    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/vgg16/vgg16.onnx"
     verify_result = verify_tool.label_verification(model_filename)
     verify_result = verify_tool.label_verification(model_filename)
     print(f"verify_result: {verify_result}")
     print(f"verify_result: {verify_result}")

+ 72 - 50
watermark_verify/verify_tool.py

@@ -1,7 +1,7 @@
 import os
 import os
 
 
-import cv2
 import numpy as np
 import numpy as np
+from PIL import Image
 
 
 from watermark_verify import logger
 from watermark_verify import logger
 from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
 from watermark_verify.tools import secret_label_func, qrcode_tool, parse_qrcode_label_file
@@ -11,11 +11,10 @@ import onnxruntime as ort
 def label_verification(model_filename: str) -> bool:
 def label_verification(model_filename: str) -> bool:
     """
     """
     模型标签提取验证
     模型标签提取验证
-    :param model_filename: 模型权重文件,om格式
+    :param model_filename: 模型权重文件,onnx格式
     :return: 模型标签验证结果
     :return: 模型标签验证结果
     """
     """
     root_dir = os.path.dirname(model_filename)
     root_dir = os.path.dirname(model_filename)
-    label_check_result = False
     logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
     logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
     # step 1 获取触发集目录,公钥信息
     # step 1 获取触发集目录,公钥信息
     trigger_dir = os.path.join(root_dir, 'trigger')
     trigger_dir = os.path.join(root_dir, 'trigger')
@@ -36,22 +35,17 @@ def label_verification(model_filename: str) -> bool:
     if not os.path.exists(qrcode_positions_file):
     if not os.path.exists(qrcode_positions_file):
         raise FileNotFoundError("二维码标签文件不存在")
         raise FileNotFoundError("二维码标签文件不存在")
 
 
-    # step 2 获取权重文件,使用触发集进行模型推理, 判断每种触发集推理结果是否为预期图片分类,如果均比对成功则进行下一步,否则返回False
-    watermark_detect_result = False
-    cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
-    accessed_cls = set()
-    for cls, images in cls_image_mapping.items():
-        for image in images:
-            image_path = os.path.join(trigger_dir, image)
-            detect_result = predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
-            if detect_result:
-                accessed_cls.add(cls)
-                break
-    if accessed_cls == set(cls_image_mapping.keys()):  # 所有的分类都检测出模型水印,模型水印检测结果为True
-        watermark_detect_result = True
-
-    if not watermark_detect_result:  # 如果没有从模型中检测出黑盒水印,直接返回验证失败
-        return False
+    # step 2 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
+    # 加载 ONNX 模型
+    session = ort.InferenceSession(model_filename)
+    for i in range(0,3):
+        image_dir = os.path.join(trigger_dir, 'images', str(i))
+        if not os.path.exists(image_dir):
+            logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
+            return False
+        batch_result = batch_predict_images(session, image_dir, i)
+        if not batch_result:
+            return False
 
 
     # step 3 从触发集图片中提取密码标签,进行验签
     # step 3 从触发集图片中提取密码标签,进行验签
     secret_label = extract_crypto_label_from_trigger(trigger_dir)
     secret_label = extract_crypto_label_from_trigger(trigger_dir)
@@ -94,42 +88,70 @@ def extract_crypto_label_from_trigger(trigger_dir: str):
                 break
                 break
     return label
     return label
 
 
+def process_image(image_path):
+    # 打开图像并转换为RGB
+    image = Image.open(image_path).convert("RGB")
 
 
-def preproc(img, input_size, swap=(2, 0, 1)):
-    if len(img.shape) == 3:
-        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
-    else:
-        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+    # 调整图像大小
+    image = image.resize((224, 224))
 
 
-    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
-    resized_img = cv2.resize(
-        img,
-        (int(img.shape[1] * r), int(img.shape[0] * r)),
-        interpolation=cv2.INTER_LINEAR,
-    ).astype(np.uint8)
-    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+    # 转换为numpy数组并归一化
+    image_array = np.array(image) / 255.0  # 将像素值缩放到[0, 1]
 
 
-    padded_img = padded_img.transpose(swap)
-    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
-    return padded_img, r
+    # 进行标准化
+    mean = np.array([0.485, 0.456, 0.406])
+    std = np.array([0.229, 0.224, 0.225])
+    image_array = (image_array - mean) / std
+    image_array = image_array.transpose((2, 0, 1)).copy()
 
 
+    return image_array.astype(np.float32)
 
 
-def predict_and_detect(image_path, model_filename, qrcode_positions_file, input_shape):
-    # 加载ONNX模型
-    session = ort.InferenceSession(model_filename)
 
 
-    # 加载图像并进行预处理
-    origin_img = cv2.imread(image_path)
-    img, ratio = preproc(origin_img, input_shape)
-
-    # 解析标签文件
-    _, _, _, _, cls = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file, image_path)
-
-    # 执行推理
+def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_size=10):
+    """
+    对指定图片文件夹图片进行批量检测
+    :param session: onnx runtime session
+    :param image_dir: 待推理的图像文件夹
+    :param target_class: 目标分类
+    :param threshold: 通过测试阈值
+    :param batch_size: 每批图片数量
+    :return: 检测结果
+    """
+    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+    results = {}
     input_name = session.get_inputs()[0].name
     input_name = session.get_inputs()[0].name
-    output_name = session.get_outputs()[0].name
-    result = session.run([output_name], {input_name: img[None, :, :, :]})[0]
 
 
-    # 处理输出结果
-    predicted_class = np.argmax(result, axis=1)[0]
-    return cls == predicted_class
+    for i in range(0, len(image_files), batch_size):
+        correct_predictions = 0
+        total_predictions = 0
+        batch_files = image_files[i:i + batch_size]
+        batch_images = []
+
+        for image_file in batch_files:
+            image_path = os.path.join(image_dir, image_file)
+            image = process_image(image_path)
+            batch_images.append(image)
+
+        # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
+        batch_images = np.stack(batch_images)
+
+        # 执行预测
+        outputs = session.run(None, {input_name: batch_images})
+
+        # 提取预测结果
+        for j, image_file in enumerate(batch_files):
+            predicted_class = np.argmax(outputs[0][j])  # 假设输出是每类的概率
+            results[image_file] = predicted_class
+            total_predictions += 1
+
+            # 比较预测结果与目标分类
+            if predicted_class == target_class:
+                correct_predictions += 1
+
+        # 计算准确率
+        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
+        logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
+        if accuracy > threshold:
+            logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} > threshold {threshold}")
+            return True
+    return False