Browse Source

Revert "新增水印检测脚本,模型文件剪枝脚本"

This reverts commit 9df02e146cdcf599696a4a03b195175e72dbd2d8.
liyan 4 months ago
parent
commit
d128f4f5b3
3 changed files with 0 additions and 164 deletions
  1. 0 78
      tests/prune_tool.py
  2. 0 77
      tests/verify_tool_accuracy_test.py
  3. 0 9
      tests/verify_tool_test.py

+ 0 - 78
tests/prune_tool.py

@@ -1,78 +0,0 @@
-"""
-对onnx权重文件进行规则剪枝
-"""
-import argparse
-import os
-
-import onnx
-import numpy as np
-from onnx import numpy_helper
-
-def prune_weights(model_path, pruned_model, pruning_percentage=0.05):
-    model = onnx.load(model_path)
-    # 获取所有权重的初始化器
-    weight_initializers = [init for init in model.graph.initializer if
-                           init.name in {node.input[1] for node in model.graph.node if node.op_type == 'Conv'}]
-
-    # 收集所有权重的绝对值
-    all_weights = np.concatenate([numpy_helper.to_array(init).flatten() for init in weight_initializers])
-
-    # 计算阈值
-    num_weights_to_prune = int(len(all_weights) * pruning_percentage)
-    threshold = np.partition(np.abs(all_weights), num_weights_to_prune)[num_weights_to_prune]
-
-    # 剪枝权重并更新初始化器
-    new_initializers = []
-    for init in model.graph.initializer:
-        if init.name in {node.input[1] for node in model.graph.node if node.op_type == 'Conv'}:
-            weights = numpy_helper.to_array(init).copy()
-            # 根据阈值剪枝权重
-            weights[np.abs(weights) < threshold] = 0
-            new_initializer = numpy_helper.from_array(weights, init.name)
-            new_initializers.append(new_initializer)
-        else:
-            new_initializers.append(init)
-
-    # 创建新的图
-    new_graph = onnx.helper.make_graph(
-        nodes=model.graph.node,
-        name=model.graph.name,
-        inputs=model.graph.input,
-        outputs=model.graph.output,
-        initializer=new_initializers
-    )
-
-    # 创建新的模型
-    new_model = onnx.helper.make_model(new_graph, producer_name='onnx-example')
-    # 保存剪枝后的模型
-    onnx.save(new_model, pruned_model)
-
-
-def find_onnx_files(root_dir):
-    onnx_files = []
-    # 遍历根目录及其子目录
-    for dirpath, _, filenames in os.walk(root_dir):
-        # 查找所有以 .onnx 结尾的文件
-        for filename in filenames:
-            if filename.endswith('.onnx'):
-                # 获取完整路径并添加到列表
-                onnx_files.append(os.path.join(dirpath, filename))
-    return onnx_files
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='模型文件剪枝工具')
-    parser.add_argument('--target_dir', default=None, type=str, help='待剪枝的模型文件存放根目录,支持子文件夹递归处理')
-    parser.add_argument('--pruned_saved_dir', default=None, type=str, help='剪枝模型文件保存目录,默认为None,表示与原始onnx权重文件放在同一目录下')
-    parser.add_argument('--percent', default=0.05, type=float, help='规则剪枝百分比')
-    args, _ = parser.parse_known_args()
-    if args.target_dir is None:
-        raise Exception("模型目录参数不可为空")
-
-    onnx_files = find_onnx_files(args.target_dir)
-    for onnx_file in onnx_files:
-        if args.pruned_saved_dir:
-            pruned_file = args.pruned_saved_dir + '/' + os.path.basename(onnx_file).replace('.onnx', '_pruned.onnx')
-        else:
-            pruned_file = onnx_file.replace('.onnx', '_pruned.onnx')
-        prune_weights(model_path=onnx_file, pruned_model=pruned_file, pruning_percentage=args.percent)

+ 0 - 77
tests/verify_tool_accuracy_test.py

@@ -1,77 +0,0 @@
-"""
-支持所有待测模型,对指定文件夹下所有模型文件进行水印检测,并进行模型水印准确率验证
-"""
-
-import argparse
-import os
-
-from watermark_verify import verify_tool
-
-model_types = {
-    "classification": [
-        "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
-    ],
-    "object_detection": [
-        "ssd", "yolox", "faster-rcnn"
-    ],
-}
-
-def find_onnx_files(root_dir):
-    onnx_files = []
-    # 遍历根目录及其子目录
-    for dirpath, _, filenames in os.walk(root_dir):
-        # 查找所有以 .onnx 结尾的文件
-        for filename in filenames:
-            if filename.endswith('.onnx'):
-                # 获取完整路径并添加到列表
-                onnx_files.append(os.path.join(dirpath, filename))
-    return onnx_files
-
-def filter_model_dirs(model_dir, targets):
-    for target in targets:
-        if target in model_dir:
-            return True
-    return False
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
-    parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
-    parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
-    parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
-    parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
-    parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
-
-    args, _ = parser.parse_known_args()
-    if args.target_dir is None:
-        raise Exception("模型目录参数不可为空")
-    if args.model_type is None:
-        raise Exception("模型类型参数不可为空")
-    if args.except_result is None:
-        raise Exception("模型推理预期结果不可为空")
-
-    # 获取所有模型目录信息
-    model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
-    if args.model_type:
-        filter_models = model_types[args.model_type]
-        model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
-    if args.model_value:
-        model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
-
-    # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
-    for model_dir in model_dirs:
-        total = 0
-        correct = 0
-        onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
-        onnx_files = [os.path.abspath(item) for item in onnx_files]
-        if args.model_file_filter:
-            onnx_files = [item for item in onnx_files if args.model_file_filter in item]
-        else:
-            onnx_files = [item for item in onnx_files if "pruned" not in item]
-        print(f"model_name: {model_dir}\nonnx_files:")
-        print(*onnx_files, sep='\n')
-        for onnx_file in onnx_files:
-            verify_result = verify_tool.label_verification(onnx_file)
-            total += 1
-            if str(verify_result) == args.except_result:
-                correct += 1
-        print(f"model_name: {model_dir}, accuracy: {correct * 100.0 / total}%")

+ 0 - 9
tests/verify_tool_test.py

@@ -1,9 +0,0 @@
-"""
-支持所有待测模型,测试模型水印提取功能,对提供的指定模型文件进行水印检测
-"""
-from watermark_verify import verify_tool
-
-if __name__ == '__main__':
-    model_filename = "/mnt/d/WorkSpace/PyCharmGitWorkspace/model_watermark_verify/tests/models/origin/googlenet/googlenet.onnx"
-    verify_result = verify_tool.label_verification(model_filename)
-    print(f"verify_result: {verify_result}")