Преглед изворни кода

增加获取模型层数使用、增加参数支持多模型验证、与日志输出

zhy пре 1 недеља
родитељ
комит
83280b46e5
1 измењених фајлова са 49 додато и 11 уклоњено
  1. 49 11
      tests/verify_tool_accuracy_test.py

+ 49 - 11
tests/verify_tool_accuracy_test.py

@@ -4,9 +4,9 @@
 
 import argparse
 import os
-import sys
-# rootpath = str(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
-# sys.path.append(rootpath)
+
+# 获取模型层数使用
+import onnx
 
 from watermark_verify import verify_tool_mix
 
@@ -19,6 +19,17 @@ model_types = {
     ],
 }
 
+# 获取模型层数函数
+def get_onnx_layer_info(onnx_path):
+    try:
+        model = onnx.load(onnx_path)
+        nodes = model.graph.node
+        total_layers = len(nodes)
+        return total_layers
+    except Exception as e:
+        print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
+        return False
+
 def find_onnx_files(root_dir):
     onnx_files = []
     # 遍历根目录及其子目录
@@ -39,20 +50,27 @@ def filter_model_dirs(model_dir, targets):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
     parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
+    parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
+    parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn、alexnet_keras、vgg16_tensorflow')
     parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
     parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
-    
-    parser.add_argument('--framework', default='pytorch', type=str, help='框架类型 (pytorch 或 tensorflow)')
-    parser.add_argument('--mode', default='blackbox', type=str, help='验证模式 (blackbox 或 whitebox)')
-    parser.add_argument('--model_type', default='yolox', type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、faster_rcnn')
+    parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
 
     args, _ = parser.parse_known_args()
     if args.target_dir is None:
         raise Exception("模型目录参数不可为空")
+    if args.model_type is None:
+        raise Exception("模型类型参数不可为空")
     if args.except_result is None:
         raise Exception("模型推理预期结果不可为空")
 
-    model_dirs = [args.target_dir]
+    # 获取所有模型目录信息
+    model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
+    if args.model_type:
+        filter_models = model_types[args.model_type]
+        model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
+    if args.model_value:
+        model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
 
     # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
     for model_dir in model_dirs:
@@ -67,12 +85,32 @@ if __name__ == '__main__':
         print(f"model_name: {model_dir}\nonnx_files:")
         print(*onnx_files, sep='\n')
         for onnx_file in onnx_files:
-            verify_result = verify_tool_mix.label_verification(onnx_file, args.framework, args.mode, args.model_type)
-            print(f"onnx_file: {onnx_file}, verify_result: {verify_result}")
+            # 打印模型层数信息
+            total_layers = get_onnx_layer_info(onnx_file)
+            print(f"ONNX模型层数统计({onnx_file}):")
+            print(f"模型层数: {total_layers}")
+
+            # verify_result = verify_tool.label_verification(onnx_file)
+            # 如果model_value包含keras,则使用keras框架,包含 tensorflow则使用tensorflow,否则使用pytorch框架
+            if 'keras' in args.model_value:
+                framework = 'keras'
+            elif 'tensorflow' in args.model_value:
+                framework = 'tensorflow'
+            else:
+                framework = 'pytorch'
+                
+            # 如果model_value包含_,则使用_前面的,否则使用args.model_value
+            model_value = args.model_value
+            if "_" in model_value:
+                model_value = model_value.split("_")[0]
+
+            # 调用验证工具进行标签验证
+            verify_result = verify_tool_mix.label_verification(onnx_file, framework=framework, mode=args.mode, model_type=model_value)
             total += 1
             if str(verify_result) == args.except_result:
                 correct += 1
-        print(f"共验证: {len(onnx_files)}个")
+        print(f"共验证: {total}个")
         print(f"验证成功: {correct}个")
+        print(f"成功率计算说明:(验证成功个数 * 100.0 / 总验证个数)%")
         print("------------------准确率指标如下-------------------------")
         print(f"模型名称: {model_dir}, 准确率: {correct * 100.0 / total}%")