|
@@ -20,14 +20,42 @@ from watermark_verify import verify_tool_mix
|
|
|
# }
|
|
|
|
|
|
# 获取模型层数函数
|
|
|
-def get_onnx_layer_info(onnx_path):
|
|
|
- try:
|
|
|
- model = onnx.load(onnx_path)
|
|
|
- nodes = model.graph.node
|
|
|
- total_layers = len(nodes)
|
|
|
- return total_layers
|
|
|
- except Exception as e:
|
|
|
- print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
|
|
|
+def get_onnx_layer_info(model_path):
|
|
|
+ """
|
|
|
+ 获取 onnx 或 om 的层数
|
|
|
+ - onnx 直接统计
|
|
|
+ - om 自动去上级目录的 onnx 子目录找同名 onnx 文件
|
|
|
+ """
|
|
|
+ ext = os.path.splitext(model_path)[1].lower()
|
|
|
+ if ext == '.onnx':
|
|
|
+ try:
|
|
|
+ import onnx
|
|
|
+ model = onnx.load(model_path)
|
|
|
+ nodes = model.graph.node
|
|
|
+ return len(nodes)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[!] 读取ONNX层数失败: {model_path}\n原因: {e}")
|
|
|
+ return False
|
|
|
+ elif ext == '.om':
|
|
|
+ # /a/b/om/model_x.om → /a/b/onnx/model_x.onnx
|
|
|
+ om_dir = os.path.dirname(model_path) # /a/b/om
|
|
|
+ parent_dir = os.path.dirname(om_dir) # /a/b
|
|
|
+ om_base = os.path.splitext(os.path.basename(model_path))[0] # model_x
|
|
|
+ onnx_path = os.path.join(parent_dir, "onnx", om_base + ".onnx")
|
|
|
+ if os.path.exists(onnx_path):
|
|
|
+ try:
|
|
|
+ import onnx
|
|
|
+ model = onnx.load(onnx_path)
|
|
|
+ nodes = model.graph.node
|
|
|
+ return len(nodes)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[!] 读取同名ONNX层数失败: {onnx_path}\n原因: {e}")
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ print(f"[!] 未找到同名ONNX文件: {onnx_path}")
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ print(f"[!] 不支持的模型格式: {model_path}")
|
|
|
return False
|
|
|
|
|
|
def find_onnx_files(root_dir):
|
|
@@ -41,6 +69,17 @@ def find_onnx_files(root_dir):
|
|
|
onnx_files.append(os.path.join(dirpath, filename))
|
|
|
return onnx_files
|
|
|
|
|
|
+def find_om_files(root_dir):
|
|
|
+ om_files = []
|
|
|
+ # 遍历根目录及其子目录
|
|
|
+ for dirpath, _, filenames in os.walk(root_dir):
|
|
|
+ # 查找所有以 .om 结尾的文件
|
|
|
+ for filename in filenames:
|
|
|
+ if filename.endswith('.om'):
|
|
|
+ # 获取完整路径并添加到列表
|
|
|
+ om_files.append(os.path.join(dirpath, filename))
|
|
|
+ return om_files
|
|
|
+
|
|
|
def filter_model_dirs(model_dir, targets):
|
|
|
for target in targets:
|
|
|
if target in model_dir:
|
|
@@ -84,7 +123,8 @@ if __name__ == '__main__':
|
|
|
for model_dir in model_dirs:
|
|
|
total = 0
|
|
|
correct = 0
|
|
|
- onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
|
|
|
+ # onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir))
|
|
|
+ onnx_files = find_om_files(os.path.join(args.target_dir, model_dir))
|
|
|
onnx_files = [os.path.abspath(item) for item in onnx_files]
|
|
|
if args.model_file_filter:
|
|
|
onnx_files = [item for item in onnx_files if args.model_file_filter in item]
|
|
@@ -101,6 +141,7 @@ if __name__ == '__main__':
|
|
|
# verify_result = verify_tool.label_verification(onnx_file)
|
|
|
# 调用验证工具进行标签验证
|
|
|
verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
|
|
|
+ print(f"模型标签验证结果: {verify_result}")
|
|
|
total += 1
|
|
|
if str(verify_result) == args.except_result:
|
|
|
correct += 1
|