|
@@ -25,7 +25,9 @@ def get_onnx_layer_info(onnx_path):
|
|
|
model = onnx.load(onnx_path)
|
|
|
nodes = model.graph.node
|
|
|
total_layers = len(nodes)
|
|
|
- return total_layers
|
|
|
+ main_layer_types = {"Conv", "BatchNormalization", "Gemm", "Relu", "MaxPool", "AveragePool", "Add"}
|
|
|
+ count = sum(1 for node in nodes if node.op_type in main_layer_types)
|
|
|
+ return total_layers, count
|
|
|
except Exception as e:
|
|
|
print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
|
|
|
return False
|
|
@@ -94,13 +96,14 @@ if __name__ == '__main__':
|
|
|
print(*onnx_files, sep='\n')
|
|
|
for onnx_file in onnx_files:
|
|
|
# 打印模型层数信息
|
|
|
- total_layers = get_onnx_layer_info(onnx_file)
|
|
|
+ total_layers, count = get_onnx_layer_info(onnx_file)
|
|
|
print(f"ONNX模型层数统计({onnx_file}):")
|
|
|
- print(f"模型层数: {total_layers}")
|
|
|
+ print(f"模型层数: {count}, 所有算子节点: {total_layers}")
|
|
|
|
|
|
# verify_result = verify_tool.label_verification(onnx_file)
|
|
|
# 调用验证工具进行标签验证
|
|
|
verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
|
|
|
+ print(f"验证结果: {verify_result}")
|
|
|
total += 1
|
|
|
if str(verify_result) == args.except_result:
|
|
|
correct += 1
|