Преглед на файлове

修改获取模型层数函数

zhy преди 1 месец
родител
ревизия
e29bb2d7db
променени са 1 файла, в които са добавени 50 реда и са изтрити 9 реда
  1. 50 9
      tests/verify_tool_accuracy_test.py

+ 50 - 9
tests/verify_tool_accuracy_test.py

@@ -20,14 +20,42 @@ from watermark_verify import verify_tool_mix
 # }
 
 # 获取模型层数函数
-def get_onnx_layer_info(onnx_path):
-    try:
-        model = onnx.load(onnx_path)
-        nodes = model.graph.node
-        total_layers = len(nodes)
-        return total_layers
-    except Exception as e:
-        print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
+def get_onnx_layer_info(model_path):
+    """
+    获取 onnx 或 om 的层数
+    - onnx 直接统计
+    - om 自动去上级目录的 onnx 子目录找同名 onnx 文件
+    """
+    ext = os.path.splitext(model_path)[1].lower()
+    if ext == '.onnx':
+        try:
+            import onnx
+            model = onnx.load(model_path)
+            nodes = model.graph.node
+            return len(nodes)
+        except Exception as e:
+            print(f"[!] 读取ONNX层数失败: {model_path}\n原因: {e}")
+            return False
+    elif ext == '.om':
+        # /a/b/om/model_x.om → /a/b/onnx/model_x.onnx
+        om_dir = os.path.dirname(model_path)           # /a/b/om
+        parent_dir = os.path.dirname(om_dir)           # /a/b
+        om_base = os.path.splitext(os.path.basename(model_path))[0]  # model_x
+        onnx_path = os.path.join(parent_dir, "onnx", om_base + ".onnx")
+        if os.path.exists(onnx_path):
+            try:
+                import onnx
+                model = onnx.load(onnx_path)
+                nodes = model.graph.node
+                return len(nodes)
+            except Exception as e:
+                print(f"[!] 读取同名ONNX层数失败: {onnx_path}\n原因: {e}")
+                return False
+        else:
+            print(f"[!] 未找到同名ONNX文件: {onnx_path}")
+            return False
+    else:
+        print(f"[!] 不支持的模型格式: {model_path}")
         return False
 
 def find_onnx_files(root_dir):
@@ -41,6 +69,17 @@ def find_onnx_files(root_dir):
                 onnx_files.append(os.path.join(dirpath, filename))
     return onnx_files
 
+def find_om_files(root_dir):
+    om_files = []
+    # 遍历根目录及其子目录
+    for dirpath, _, filenames in os.walk(root_dir):
+        # 查找所有以 .om 结尾的文件
+        for filename in filenames:
+            if filename.endswith('.om'):
+                # 获取完整路径并添加到列表
+                om_files.append(os.path.join(dirpath, filename))
+    return om_files
+
 def filter_model_dirs(model_dir, targets):
     for target in targets:
         if target in model_dir:
@@ -84,7 +123,8 @@ if __name__ == '__main__':
     for model_dir in model_dirs:
         total = 0
         correct = 0
-        onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
+        # onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
+        onnx_files = find_om_files(os.path.join(args.target_dir, model_dir))
         onnx_files = [os.path.abspath(item) for item in onnx_files]
         if args.model_file_filter:
             onnx_files = [item for item in onnx_files if args.model_file_filter in item]
@@ -101,6 +141,7 @@ if __name__ == '__main__':
             # verify_result = verify_tool.label_verification(onnx_file)
             # 调用验证工具进行标签验证
             verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
+            print(f"模型标签验证结果: {verify_result}")
             total += 1
             if str(verify_result) == args.except_result:
                 correct += 1