""" 支持所有待测模型,对指定文件夹下所有模型文件进行水印检测,并进行模型水印准确率验证 """ import argparse import os # 获取模型层数使用 import onnx from watermark_verify import verify_tool_mix model_types = { "classification": [ "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet" ], "object_detection": [ "ssd", "yolox", "faster-rcnn" ], } # 获取模型层数函数 def get_onnx_layer_info(onnx_path): try: model = onnx.load(onnx_path) nodes = model.graph.node total_layers = len(nodes) return total_layers except Exception as e: print(f"[!] 读取模型层数失败: {onnx_path}\n原因: {e}") return False def find_onnx_files(root_dir): onnx_files = [] # 遍历根目录及其子目录 for dirpath, _, filenames in os.walk(root_dir): # 查找所有以 .onnx 结尾的文件 for filename in filenames: if filename.endswith('.onnx'): # 获取完整路径并添加到列表 onnx_files.append(os.path.join(dirpath, filename)) return onnx_files def filter_model_dirs(model_dir, targets): for target in targets: if target in model_dir: return True return False if __name__ == '__main__': parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本') parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理') parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect') parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn、alexnet_keras、vgg16_tensorflow') parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None') parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None') parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox') args, _ = parser.parse_known_args() if args.target_dir is None: raise Exception("模型目录参数不可为空") if args.model_type is None: raise Exception("模型类型参数不可为空") if args.except_result is None: raise Exception("模型推理预期结果不可为空") # 获取所有模型目录信息 model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))] if args.model_type: filter_models = model_types[args.model_type] model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)] if args.model_value: model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()] # 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率 for model_dir in model_dirs: total = 0 correct = 0 onnx_files = find_onnx_files(os.path.join(args.target_dir, model_dir)) onnx_files = [os.path.abspath(item) for item in onnx_files] if args.model_file_filter: onnx_files = [item for item in onnx_files if args.model_file_filter in item] else: onnx_files = [item for item in onnx_files if "pruned" not in item] print(f"model_name: {model_dir}\nonnx_files:") print(*onnx_files, sep='\n') for onnx_file in onnx_files: # 打印模型层数信息 total_layers = get_onnx_layer_info(onnx_file) print(f"ONNX模型层数统计({onnx_file}):") print(f"模型层数: {total_layers}") # verify_result = verify_tool.label_verification(onnx_file) # 如果model_value包含keras,则使用keras框架,包含 tensorflow则使用tensorflow,否则使用pytorch框架 if 'keras' in args.model_value: framework = 'keras' elif 'tensorflow' in args.model_value: framework = 'tensorflow' else: framework = 'pytorch' # 如果model_value包含_,则使用_前面的,否则使用args.model_value model_value = args.model_value if "_" in model_value: model_value = model_value.split("_")[0] # 调用验证工具进行标签验证 verify_result = verify_tool_mix.label_verification(onnx_file, framework=framework, mode=args.mode, model_type=model_value) total += 1 if str(verify_result) == args.except_result: correct += 1 print(f"共验证: {total}个") print(f"验证成功: {correct}个") print(f"成功率计算说明:(验证成功个数 * 100.0 / 总验证个数)%") print("------------------准确率指标如下-------------------------") print(f"模型名称: {model_dir}, 准确率: {correct * 100.0 / total}%")