verify_tool_accuracy_test.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. """
  2. 支持所有待测模型,对指定文件夹下所有模型文件进行水印检测,并进行模型水印准确率验证
  3. """
  4. import argparse
  5. import os
  6. # 获取模型层数使用
  7. import onnx
  8. from watermark_verify import verify_tool_mix
  9. # model_types = {
  10. # "classification": [
  11. # "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
  12. # ],
  13. # "object_detection": [
  14. # "ssd", "yolox", "faster-rcnn"
  15. # ],
  16. # }
  17. # 获取模型层数函数
  18. def get_onnx_layer_info(model_path):
  19. """
  20. 获取 onnx 或 om 的层数
  21. - onnx 直接统计
  22. - om 自动去上级目录的 onnx 子目录找同名 onnx 文件
  23. """
  24. ext = os.path.splitext(model_path)[1].lower()
  25. if ext == '.onnx':
  26. try:
  27. import onnx
  28. model = onnx.load(model_path)
  29. nodes = model.graph.node
  30. return len(nodes)
  31. except Exception as e:
  32. print(f"[!] 读取ONNX层数失败: {model_path}\n原因: {e}")
  33. return False
  34. elif ext == '.om':
  35. # /a/b/om/model_x.om → /a/b/onnx/model_x.onnx
  36. om_dir = os.path.dirname(model_path) # /a/b/om
  37. parent_dir = os.path.dirname(om_dir) # /a/b
  38. om_base = os.path.splitext(os.path.basename(model_path))[0] # model_x
  39. onnx_path = os.path.join(parent_dir, "onnx", om_base + ".onnx")
  40. if os.path.exists(onnx_path):
  41. try:
  42. import onnx
  43. model = onnx.load(onnx_path)
  44. nodes = model.graph.node
  45. return len(nodes)
  46. except Exception as e:
  47. print(f"[!] 读取同名ONNX层数失败: {onnx_path}\n原因: {e}")
  48. return False
  49. else:
  50. print(f"[!] 未找到同名ONNX文件: {onnx_path}")
  51. return False
  52. else:
  53. print(f"[!] 不支持的模型格式: {model_path}")
  54. return False
  55. def find_onnx_files(root_dir):
  56. onnx_files = []
  57. # 遍历根目录及其子目录
  58. for dirpath, _, filenames in os.walk(root_dir):
  59. # 查找所有以 .onnx 结尾的文件
  60. for filename in filenames:
  61. if filename.endswith('.onnx'):
  62. # 获取完整路径并添加到列表
  63. onnx_files.append(os.path.join(dirpath, filename))
  64. return onnx_files
  65. def find_om_files(root_dir):
  66. om_files = []
  67. # 遍历根目录及其子目录
  68. for dirpath, _, filenames in os.walk(root_dir):
  69. # 查找所有以 .om 结尾的文件
  70. for filename in filenames:
  71. if filename.endswith('.om'):
  72. # 获取完整路径并添加到列表
  73. om_files.append(os.path.join(dirpath, filename))
  74. return om_files
  75. def filter_model_dirs(model_dir, targets):
  76. for target in targets:
  77. if target in model_dir:
  78. return True
  79. return False
  80. if __name__ == '__main__':
  81. parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
  82. parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
  83. # parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
  84. parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
  85. parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
  86. parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
  87. parser.add_argument('--model_type', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
  88. parser.add_argument('--framework', default=None, type=str, help='模型类型分类,支持分类模型和目标检测模型,可选参数:pytorch、tensorflow、keras')
  89. args, _ = parser.parse_known_args()
  90. if args.target_dir is None:
  91. raise Exception("模型目录参数不可为空")
  92. if args.model_type is None:
  93. raise Exception("模型类型参数不可为空")
  94. if args.mode is None:
  95. raise Exception("验证模式参数不可为空")
  96. if args.framework is None:
  97. raise Exception("框架类型参数不可为空")
  98. if args.except_result is None:
  99. raise Exception("模型推理预期结果不可为空")
  100. # 获取所有模型目录信息
  101. # model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
  102. # if args.model_type:
  103. # filter_models = model_types[args.model_type]
  104. # model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
  105. # if args.model_value:
  106. # model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
  107. model_dirs = [args.target_dir]
  108. # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
  109. for model_dir in model_dirs:
  110. total = 0
  111. correct = 0
  112. # onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
  113. onnx_files = find_om_files(os.path.join(args.target_dir, model_dir))
  114. onnx_files = [os.path.abspath(item) for item in onnx_files]
  115. if args.model_file_filter:
  116. onnx_files = [item for item in onnx_files if args.model_file_filter in item]
  117. else:
  118. onnx_files = [item for item in onnx_files if "pruned" not in item]
  119. print(f"model_name: {model_dir}\nonnx_files:")
  120. print(*onnx_files, sep='\n')
  121. for onnx_file in onnx_files:
  122. # 打印模型层数信息
  123. total_layers = get_onnx_layer_info(onnx_file)
  124. print(f"ONNX模型层数统计({onnx_file}):")
  125. print(f"模型层数: {total_layers}")
  126. # verify_result = verify_tool.label_verification(onnx_file)
  127. # 调用验证工具进行标签验证
  128. verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
  129. print(f"模型标签验证结果: {verify_result}")
  130. total += 1
  131. if str(verify_result) == args.except_result:
  132. correct += 1
  133. print(f"共验证: {total}个")
  134. print(f"验证成功: {correct}个")
  135. print(f"成功率计算说明:(验证成功个数 * 100.0 / 总验证个数)%")
  136. print("------------------准确率指标如下-------------------------")
  137. print(f"模型名称: {model_dir}, 准确率: {correct * 100.0 / total}%")