|
@@ -1,77 +0,0 @@
|
|
|
-"""
|
|
|
-支持所有待测模型,对指定文件夹下所有模型文件进行水印检测,并进行模型水印准确率验证
|
|
|
-"""
|
|
|
-
|
|
|
-import argparse
|
|
|
-import os
|
|
|
-
|
|
|
-from watermark_verify import verify_tool
|
|
|
-
|
|
|
-model_types = {
|
|
|
- "classification": [
|
|
|
- "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
|
|
|
- ],
|
|
|
- "object_detection": [
|
|
|
- "ssd", "yolox", "faster-rcnn"
|
|
|
- ],
|
|
|
-}
|
|
|
-
|
|
|
-def find_onnx_files(root_dir):
|
|
|
- onnx_files = []
|
|
|
- # 遍历根目录及其子目录
|
|
|
- for dirpath, _, filenames in os.walk(root_dir):
|
|
|
- # 查找所有以 .onnx 结尾的文件
|
|
|
- for filename in filenames:
|
|
|
- if filename.endswith('.onnx'):
|
|
|
- # 获取完整路径并添加到列表
|
|
|
- onnx_files.append(os.path.join(dirpath, filename))
|
|
|
- return onnx_files
|
|
|
-
|
|
|
-def filter_model_dirs(model_dir, targets):
|
|
|
- for target in targets:
|
|
|
- if target in model_dir:
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
-if __name__ == '__main__':
|
|
|
- parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
|
|
|
- parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
|
|
|
- parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
|
|
|
- parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
|
|
|
- parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
|
|
|
- parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
|
|
|
-
|
|
|
- args, _ = parser.parse_known_args()
|
|
|
- if args.target_dir is None:
|
|
|
- raise Exception("模型目录参数不可为空")
|
|
|
- if args.model_type is None:
|
|
|
- raise Exception("模型类型参数不可为空")
|
|
|
- if args.except_result is None:
|
|
|
- raise Exception("模型推理预期结果不可为空")
|
|
|
-
|
|
|
- # 获取所有模型目录信息
|
|
|
- model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
|
|
|
- if args.model_type:
|
|
|
- filter_models = model_types[args.model_type]
|
|
|
- model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
|
|
|
- if args.model_value:
|
|
|
- model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
|
|
|
-
|
|
|
- # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
|
|
|
- for model_dir in model_dirs:
|
|
|
- total = 0
|
|
|
- correct = 0
|
|
|
- onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
|
|
|
- onnx_files = [os.path.abspath(item) for item in onnx_files]
|
|
|
- if args.model_file_filter:
|
|
|
- onnx_files = [item for item in onnx_files if args.model_file_filter in item]
|
|
|
- else:
|
|
|
- onnx_files = [item for item in onnx_files if "pruned" not in item]
|
|
|
- print(f"model_name: {model_dir}\nonnx_files:")
|
|
|
- print(*onnx_files, sep='\n')
|
|
|
- for onnx_file in onnx_files:
|
|
|
- verify_result = verify_tool.label_verification(onnx_file)
|
|
|
- total += 1
|
|
|
- if str(verify_result) == args.except_result:
|
|
|
- correct += 1
|
|
|
- print(f"model_name: {model_dir}, accuracy: {correct * 100.0 / total}%")
|