Bläddra i källkod

增加过滤层数同级

zhy 2 veckor sedan
förälder
incheckning
7de24ae40a
1 ändrade filer med 6 tillägg och 3 borttagningar
  1. 6 3
      tests/verify_tool_accuracy_test.py

+ 6 - 3
tests/verify_tool_accuracy_test.py

@@ -25,7 +25,9 @@ def get_onnx_layer_info(onnx_path):
         model = onnx.load(onnx_path)
         nodes = model.graph.node
         total_layers = len(nodes)
-        return total_layers
+        main_layer_types = {"Conv", "BatchNormalization", "Gemm", "Relu", "MaxPool", "AveragePool", "Add"}
+        count = sum(1 for node in nodes if node.op_type in main_layer_types)
+        return total_layers, count
     except Exception as e:
         print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
         return False
@@ -94,13 +96,14 @@ if __name__ == '__main__':
         print(*onnx_files, sep='\n')
         for onnx_file in onnx_files:
             # 打印模型层数信息
-            total_layers = get_onnx_layer_info(onnx_file)
+            total_layers, count = get_onnx_layer_info(onnx_file)
             print(f"ONNX模型层数统计({onnx_file}):")
-            print(f"模型层数: {total_layers}")
+            print(f"模型层数: {count}, 所有算子节点: {total_layers}")
 
             # verify_result = verify_tool.label_verification(onnx_file)
             # 调用验证工具进行标签验证
             verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_type)
+            print(f"验证结果: {verify_result}")
             total += 1
             if str(verify_result) == args.except_result:
                 correct += 1