|
@@ -5,7 +5,10 @@
|
|
|
import argparse
|
|
|
import os
|
|
|
|
|
|
-from watermark_verify import verify_tool
|
|
|
+# 获取模型层数使用
|
|
|
+import onnx
|
|
|
+
|
|
|
+from watermark_verify import verify_tool_mix
|
|
|
|
|
|
model_types = {
|
|
|
"classification": [
|
|
@@ -16,6 +19,17 @@ model_types = {
|
|
|
],
|
|
|
}
|
|
|
|
|
|
+# 获取模型层数函数
|
|
|
+def get_onnx_layer_info(onnx_path):
|
|
|
+ try:
|
|
|
+ model = onnx.load(onnx_path)
|
|
|
+ nodes = model.graph.node
|
|
|
+ total_layers = len(nodes)
|
|
|
+ return total_layers
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
def find_onnx_files(root_dir):
|
|
|
onnx_files = []
|
|
|
# 遍历根目录及其子目录
|
|
@@ -37,9 +51,10 @@ if __name__ == '__main__':
|
|
|
parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
|
|
|
parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
|
|
|
parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
|
|
|
- parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
|
|
|
+ parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn、alexnet_keras、vgg16_tensorflow')
|
|
|
parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
|
|
|
parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
|
|
|
+ parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
|
|
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
if args.target_dir is None:
|
|
@@ -70,11 +85,32 @@ if __name__ == '__main__':
|
|
|
print(f"model_name: {model_dir}\nonnx_files:")
|
|
|
print(*onnx_files, sep='\n')
|
|
|
for onnx_file in onnx_files:
|
|
|
- verify_result = verify_tool.label_verification(onnx_file)
|
|
|
+ # 打印模型层数信息
|
|
|
+ total_layers = get_onnx_layer_info(onnx_file)
|
|
|
+ print(f"ONNX模型层数统计({onnx_file}):")
|
|
|
+ print(f"模型层数: {total_layers}")
|
|
|
+
|
|
|
+ # verify_result = verify_tool.label_verification(onnx_file)
|
|
|
+ # 如果model_value包含keras,则使用keras框架,包含 tensorflow则使用tensorflow,否则使用pytorch框架
|
|
|
+ if 'keras' in args.model_value:
|
|
|
+ framework = 'keras'
|
|
|
+ elif 'tensorflow' in args.model_value:
|
|
|
+ framework = 'tensorflow'
|
|
|
+ else:
|
|
|
+ framework = 'pytorch'
|
|
|
+
|
|
|
+ # 如果model_value包含_,则使用_前面的,否则使用args.model_value
|
|
|
+ model_value = args.model_value
|
|
|
+ if "_" in model_value:
|
|
|
+ model_value = model_value.split("_")[0]
|
|
|
+
|
|
|
+ # 调用验证工具进行标签验证
|
|
|
+ verify_result = verify_tool_mix.label_verification(onnx_file, framework=framework, mode=args.mode, model_type=model_value)
|
|
|
total += 1
|
|
|
if str(verify_result) == args.except_result:
|
|
|
correct += 1
|
|
|
- print(f"共验证: {len(onnx_files)}个")
|
|
|
+ print(f"共验证: {total}个")
|
|
|
print(f"验证成功: {correct}个")
|
|
|
+ print(f"成功率计算说明:(验证成功个数 * 100.0 / 总验证个数)%")
|
|
|
print("------------------准确率指标如下-------------------------")
|
|
|
print(f"模型名称: {model_dir}, 准确率: {correct * 100.0 / total}%")
|