verify_tool_accuracy_test.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """
  2. 支持所有待测模型,对指定文件夹下所有模型文件进行水印检测,并进行模型水印准确率验证
  3. """
  4. import argparse
  5. import os
  6. # 获取模型层数使用
  7. import onnx
  8. from watermark_verify import verify_tool_mix
  9. # model_types = {
  10. # "classification": [
  11. # "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
  12. # ],
  13. # "object_detection": [
  14. # "ssd", "yolox", "faster-rcnn"
  15. # ],
  16. # }
  17. # 获取模型层数函数
  18. def get_onnx_layer_info(onnx_path):
  19. try:
  20. model = onnx.load(onnx_path)
  21. nodes = model.graph.node
  22. total_layers = len(nodes)
  23. main_layer_types = {"Conv", "BatchNormalization", "Gemm", "Relu", "MaxPool", "AveragePool", "Add"}
  24. count = sum(1 for node in nodes if node.op_type in main_layer_types)
  25. return total_layers, count
  26. except Exception as e:
  27. print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
  28. return False
  29. def find_onnx_files(root_dir):
  30. onnx_files = []
  31. # 遍历根目录及其子目录
  32. for dirpath, _, filenames in os.walk(root_dir):
  33. # 查找所有以 .onnx 结尾的文件
  34. for filename in filenames:
  35. if filename.endswith('.onnx'):
  36. # 获取完整路径并添加到列表
  37. onnx_files.append(os.path.join(dirpath, filename))
  38. return onnx_files
  39. def filter_model_dirs(model_dir, targets):
  40. for target in targets:
  41. if target in model_dir:
  42. return True
  43. return False
  44. if __name__ == '__main__':
  45. parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
  46. parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
  47. # parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
  48. parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
  49. parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
  50. parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
  51. parser.add_argument('--model_type', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
  52. parser.add_argument('--framework', default=None, type=str, help='模型类型分类,支持分类模型和目标检测模型,可选参数:pytorch、tensorflow、keras')
  53. args, _ = parser.parse_known_args()
  54. if args.target_dir is None:
  55. raise Exception("模型目录参数不可为空")
  56. if args.model_type is None:
  57. raise Exception("模型类型参数不可为空")
  58. if args.mode is None:
  59. raise Exception("验证模式参数不可为空")
  60. if args.framework is None:
  61. raise Exception("框架类型参数不可为空")
  62. if args.except_result is None:
  63. raise Exception("模型推理预期结果不可为空")
  64. # 获取所有模型目录信息
  65. # model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
  66. # if args.model_type:
  67. # filter_models = model_types[args.model_type]
  68. # model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
  69. # if args.model_value:
  70. # model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
  71. model_dirs = [args.target_dir]
  72. # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
  73. for model_dir in model_dirs:
  74. total = 0
  75. correct = 0
  76. onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
  77. onnx_files = [os.path.abspath(item) for item in onnx_files]
  78. if args.model_file_filter:
  79. onnx_files = [item for item in onnx_files if args.model_file_filter in item]
  80. else:
  81. onnx_files = [item for item in onnx_files if "pruned" not in item]
  82. print(f"model_name: {model_dir}\nonnx_files:")
  83. print(*onnx_files, sep='\n')
  84. for onnx_file in onnx_files:
  85. # 打印模型层数信息
  86. total_layers, count = get_onnx_layer_info(onnx_file)
  87. print(f"ONNX模型层数统计({onnx_file}):")
  88. print(f"模型层数: {count}, 所有算子节点: {total_layers}")
  89. # verify_result = verify_tool.label_verification(onnx_file)
  90. # 调用验证工具进行标签验证
  91. verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
  92. print(f"验证结果: {verify_result}")
  93. total += 1
  94. if str(verify_result) == args.except_result:
  95. correct += 1
  96. print(f"共验证: {total}个")
  97. print(f"验证成功: {correct}个")
  98. print(f"成功率计算说明:(验证成功个数 * 100.0 / 总验证个数)%")
  99. print("------------------准确率指标如下-------------------------")
  100. print(f"模型名称: {model_dir}, 准确率: {correct * 100.0 / total}%")