Explorar o código

修改黑盒水印验证流程,新增性能测试脚本

liyan hai 8 meses
pai
achega
f034396d70

+ 79 - 0
tests/prune_tool.py

@@ -0,0 +1,79 @@
+"""
+对onnx权重文件进行规则剪枝
+"""
+import argparse
+import os
+
+import onnx
+import numpy as np
+from onnx import numpy_helper
+
def prune_weights(model_path, pruned_model, pruning_percentage=0.05):
    """Magnitude-prune the Conv weights of an ONNX model.

    Computes one global magnitude threshold over all Conv weight tensors and
    zeroes every weight whose absolute value falls below it, then saves the
    pruned model.

    Args:
        model_path: Path of the source ``.onnx`` file.
        pruned_model: Path the pruned model is written to.
        pruning_percentage: Fraction (0-1) of Conv weight values to zero out.
    """
    model = onnx.load(model_path)

    # Names of the weight tensor (second input) of every Conv node. Computed
    # once and reused; the original recomputed this set for every initializer.
    conv_weight_names = {node.input[1] for node in model.graph.node
                         if node.op_type == 'Conv' and len(node.input) > 1}
    weight_initializers = [init for init in model.graph.initializer
                           if init.name in conv_weight_names]
    if not weight_initializers:
        # No Conv weights to prune (np.concatenate would raise on an empty
        # list); save the model unchanged.
        onnx.save(model, pruned_model)
        return

    # Global threshold: the k-th smallest absolute weight value across all
    # Conv layers, where k = pruning_percentage of the total weight count.
    all_weights = np.concatenate(
        [numpy_helper.to_array(init).flatten() for init in weight_initializers])
    num_weights_to_prune = int(len(all_weights) * pruning_percentage)
    threshold = np.partition(np.abs(all_weights), num_weights_to_prune)[num_weights_to_prune]

    # Zero out sub-threshold weights. Replace the initializers in place
    # instead of rebuilding the graph/model: make_model() would silently drop
    # the original opset_import and IR version, which can break inference.
    new_initializers = []
    for init in model.graph.initializer:
        if init.name in conv_weight_names:
            weights = numpy_helper.to_array(init).copy()
            weights[np.abs(weights) < threshold] = 0
            new_initializers.append(numpy_helper.from_array(weights, init.name))
        else:
            new_initializers.append(init)
    del model.graph.initializer[:]
    model.graph.initializer.extend(new_initializers)

    # Persist the pruned model (all other model metadata kept intact).
    onnx.save(model, pruned_model)
+
+
def find_onnx_files(root_dir):
    """Recursively collect the paths of all ``.onnx`` files under root_dir."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith('.onnx')
    ]
+
+
if __name__ == '__main__':
    # CLI entry point: recursively prune every .onnx model under --target_dir.
    parser = argparse.ArgumentParser(description='模型文件剪枝工具')
    # NOTE(review): default points at a development model dir; the original
    # default of None is preserved via the explicit guard below.
    parser.add_argument('--target_dir', default="origin_models/googlenet", type=str,
                        help='待剪枝的模型文件存放根目录,支持子文件夹递归处理')
    parser.add_argument('--pruned_saved_dir', default=None, type=str,
                        help='剪枝模型文件保存目录,默认为None,表示与原始onnx权重文件放在同一目录下')
    parser.add_argument('--percent', default=0.05, type=float, help='规则剪枝百分比')
    args, _ = parser.parse_known_args()
    if args.target_dir is None:
        raise Exception("模型目录参数不可为空")

    if args.pruned_saved_dir:
        # Ensure the output directory exists before writing pruned models;
        # onnx.save does not create intermediate directories.
        os.makedirs(args.pruned_saved_dir, exist_ok=True)

    for onnx_file in find_onnx_files(args.target_dir):
        if args.pruned_saved_dir:
            # Collect all pruned models into the requested directory.
            pruned_file = os.path.join(
                args.pruned_saved_dir,
                os.path.basename(onnx_file).replace('.onnx', '_pruned.onnx'))
        else:
            # Default: place the pruned copy next to the original weight file.
            pruned_file = onnx_file.replace('.onnx', '_pruned.onnx')
        prune_weights(model_path=onnx_file, pruned_model=pruned_file,
                      pruning_percentage=args.percent)

+ 79 - 0
tests/verify_tool_accuracy_test.py

@@ -0,0 +1,79 @@
+import argparse
+import os
+
+from watermark_verify import verify_tool
+
# Mapping from model task category to the model-name keywords used to filter
# model directories via the --model_type / --model_value CLI options below.
model_types = {
    "classification": [
        "alexnet", "googlenet", "resnet", "vgg16"
    ],
    "object_detection": [
        "ssd", "yolox", "rcnn"
    ],
}
+
def find_onnx_files(root_dir):
    """Walk root_dir recursively and return the paths of all .onnx files."""
    matches = []
    for current_dir, _, names in os.walk(root_dir):
        matches.extend(os.path.join(current_dir, name)
                       for name in names if name.endswith('.onnx'))
    return matches
+
def filter_model_dirs(model_dir, targets):
    """Return True if model_dir contains any of the target keywords."""
    return any(target in model_dir for target in targets)
+
if __name__ == '__main__':
    # CLI entry point: run black-box label verification over every model
    # directory matching the filters and report per-model accuracy.
    parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
    # NOTE(review): the non-None defaults below look like temporary debug
    # overrides (the spec says the defaults should be None) — confirm before
    # shipping. '--except_result' is presumably a misspelling of
    # 'expect_result'; kept for CLI backward compatibility.
    parser.add_argument('--model_type', default="classification", type=str,
                        help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、object_detection')
    parser.add_argument('--model_value', default="googlenet", type=str,
                        help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
    parser.add_argument('--target_dir', default="origin_models", type=str,
                        help='模型文件存放根目录,支持子文件夹递归处理')
    parser.add_argument('--except_result', default="False", type=str, help='模型推理预期结果。默认为None')
    parser.add_argument('--model_file_filter', default="pruned", type=str,
                        help='按照模型文件名过滤, 比如剪枝模型文件名存在pruned。默认为None')

    args, _ = parser.parse_known_args()
    if args.target_dir is None:
        raise Exception("模型目录参数不可为空")
    if args.model_type is None:
        raise Exception("模型类型参数不可为空")
    if args.except_result is None:
        raise Exception("模型推理预期结果不可为空")
    if args.model_type not in model_types:
        # Fail with a clear message instead of a bare KeyError below.
        raise Exception(f"不支持的模型类型: {args.model_type}")

    # Collect model directories directly under target_dir, then narrow by
    # model category and model name.
    model_dirs = [item for item in os.listdir(args.target_dir)
                  if os.path.isdir(os.path.join(args.target_dir, item))]
    if args.model_type:
        filter_models = model_types[args.model_type]
        model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
    if args.model_value:
        model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]

    # Run label verification on each matching model file and record how often
    # the result equals the expected value.
    for model_dir in model_dirs:
        onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
        onnx_files = [os.path.abspath(item) for item in onnx_files]
        if args.model_file_filter:
            onnx_files = [item for item in onnx_files if args.model_file_filter in item]
        print(f"model_name: {model_dir}\nonnx_files:")
        print(*onnx_files, sep='\n')
        if not onnx_files:
            # Guard against ZeroDivisionError when the filters match nothing.
            print(f"model_name: {model_dir}, no matching onnx files, skipped")
            continue
        total = 0
        correct = 0
        for onnx_file in onnx_files:
            verify_result = verify_tool.label_verification(onnx_file)
            total += 1
            if str(verify_result) == args.except_result:
                correct += 1
        print(f"model_name: {model_dir}, accuracy: {correct * 100.0 / total}%")

+ 5 - 1
tests/verify_tool_test.py

@@ -1,6 +1,10 @@
 from watermark_verify import verify_tool
 
 if __name__ == '__main__':
-    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/vgg16/vgg16.onnx"
+    # model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/vgg16/vgg16.onnx"
+    # model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/alexnet/alexnet.onnx"
+    # model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/resnet/resnet101.onnx"
+    # model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/googlenet/googlenet.onnx"
+    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/blackbox_models/googlenet/googlenet_pruned.onnx"
     verify_result = verify_tool.label_verification(model_filename)
     print(f"verify_result: {verify_result}")

+ 4 - 4
watermark_verify/verify_tool.py

@@ -38,7 +38,7 @@ def label_verification(model_filename: str) -> bool:
     # step 2 获取权重文件,使用触发集批量进行模型推理, 如果某个批次的准确率大于阈值,则比对成功进行下一步,否则返回False
     # 加载 ONNX 模型
     session = ort.InferenceSession(model_filename)
-    for i in range(0,3):
+    for i in range(0,2):
         image_dir = os.path.join(trigger_dir, 'images', str(i))
         if not os.path.exists(image_dir):
             logger.error(f"指定触发集图片路径不存在, image_dir={image_dir}")
@@ -150,8 +150,8 @@ def batch_predict_images(session, image_dir, target_class, threshold=0.6, batch_
 
         # 计算准确率
         accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
-        logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
-        if accuracy > threshold:
-            logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} > threshold {threshold}")
+        # logger.debug(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy * 100:.2f}%")
+        if accuracy >= threshold:
+            logger.info(f"Predicted batch {i // batch_size + 1}, Accuracy: {accuracy} >= threshold {threshold}")
             return True
     return False