Browse Source

修改为可以根据参数自动切换版本

zhy 1 month ago
parent
commit
2450063d44
1 changed files with 8 additions and 15 deletions
  1. 8 15
      tests/verify_tool_accuracy_test.py

+ 8 - 15
tests/verify_tool_accuracy_test.py

@@ -5,10 +5,10 @@
 import argparse
 import os
 import sys
-rootpath = str(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
-sys.path.append(rootpath)
+# rootpath = str(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
+# sys.path.append(rootpath)
 
-from watermark_verify import verify_tool
+from watermark_verify import verify_tool_mix
 
 model_types = {
     "classification": [
@@ -39,26 +39,19 @@ def filter_model_dirs(model_dir, targets):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
     parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
-    # parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
-    # parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
     parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
     parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
+    
+    parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
+    parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
+    parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
 
     args, _ = parser.parse_known_args()
     if args.target_dir is None:
         raise Exception("模型目录参数不可为空")
-    # if args.model_type is None:
-    #     raise Exception("模型类型参数不可为空")
     if args.except_result is None:
         raise Exception("模型推理预期结果不可为空")
 
-    # 获取所有模型目录信息
-    # model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
-    # if args.model_type:
-    #     filter_models = model_types[args.model_type]
-    #     model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
-    # if args.model_value:
-    #     model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
     model_dirs = [args.target_dir]
 
     # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
@@ -74,7 +67,7 @@ if __name__ == '__main__':
         print(f"model_name: {model_dir}\nonnx_files:")
         print(*onnx_files, sep='\n')
         for onnx_file in onnx_files:
-            verify_result = verify_tool.label_verification(onnx_file)
+            verify_result = verify_tool_mix.label_verification(onnx_file, args.framework, args.mode, args.model_type)
             total += 1
             if str(verify_result) == args.except_result:
                 correct += 1