verify_tool_accuracy_test.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import argparse
  2. import os
  3. from watermark_verify import verify_tool
  4. model_types = {
  5. "classification": [
  6. "alexnet","alexnet_keras", "googlenet", "resnet", "vgg16", "vgg16_tensorflow"
  7. ],
  8. "object_detection": [
  9. "ssd", "yolox", "rcnn"
  10. ],
  11. }
  12. def find_onnx_files(root_dir):
  13. onnx_files = []
  14. # 遍历根目录及其子目录
  15. for dirpath, _, filenames in os.walk(root_dir):
  16. # 查找所有以 .onnx 结尾的文件
  17. for filename in filenames:
  18. if filename.endswith('.onnx'):
  19. # 获取完整路径并添加到列表
  20. onnx_files.append(os.path.join(dirpath, filename))
  21. return onnx_files
  22. def filter_model_dirs(model_dir, targets):
  23. for target in targets:
  24. if target in model_dir:
  25. return True
  26. return False
  27. if __name__ == '__main__':
  28. parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
  29. # parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
  30. # parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
  31. # parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
  32. # parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
  33. # parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
  34. parser.add_argument('--model_type', default="classification", type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
  35. parser.add_argument('--model_value', default="alexnet_keras", type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
  36. parser.add_argument('--target_dir', default="blackbox_models", type=str,
  37. help='模型文件存放根目录,支持子文件夹递归处理')
  38. parser.add_argument('--except_result', default="True", type=str, help='模型推理预期结果。默认为None')
  39. parser.add_argument('--model_file_filter', default="pruned", type=str,
  40. help='按照模型文件名过滤, 比如剪枝模型文件名存在pruned。默认为None')
  41. args, _ = parser.parse_known_args()
  42. if args.target_dir is None:
  43. raise Exception("模型目录参数不可为空")
  44. if args.model_type is None:
  45. raise Exception("模型类型参数不可为空")
  46. if args.except_result is None:
  47. raise Exception("模型推理预期结果不可为空")
  48. # 获取所有模型目录信息
  49. model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
  50. if args.model_type:
  51. filter_models = model_types[args.model_type]
  52. model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
  53. if args.model_value:
  54. model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
  55. # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
  56. for model_dir in model_dirs:
  57. total = 0
  58. correct = 0
  59. onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
  60. onnx_files = [os.path.abspath(item) for item in onnx_files]
  61. if args.model_file_filter:
  62. onnx_files = [item for item in onnx_files if args.model_file_filter in item]
  63. else:
  64. onnx_files = [item for item in onnx_files if "pruned" not in item]
  65. print(f"model_name: {model_dir}\nonnx_files:")
  66. print(*onnx_files, sep='\n')
  67. for onnx_file in onnx_files:
  68. verify_result = verify_tool.label_verification(onnx_file)
  69. total += 1
  70. if str(verify_result) == args.except_result:
  71. correct += 1
  72. print(f"model_name: {model_dir}, accuracy: {correct * 100.0 / total}%")