|
@@ -10,14 +10,14 @@ import onnx
|
|
|
|
|
|
from watermark_verify import verify_tool_mix
|
|
from watermark_verify import verify_tool_mix
|
|
|
|
|
|
-model_types = {
|
|
|
|
- "classification": [
|
|
|
|
- "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
|
|
|
|
- ],
|
|
|
|
- "object_detection": [
|
|
|
|
- "ssd", "yolox", "faster-rcnn"
|
|
|
|
- ],
|
|
|
|
-}
|
|
|
|
|
|
+# model_types = {
|
|
|
|
+# "classification": [
|
|
|
|
+# "alexnet","alexnet_keras", "vgg16", "vgg16_tensorflow", "googlenet", "resnet"
|
|
|
|
+# ],
|
|
|
|
+# "object_detection": [
|
|
|
|
+# "ssd", "yolox", "faster-rcnn"
|
|
|
|
+# ],
|
|
|
|
+# }
|
|
|
|
|
|
# 获取模型层数函数
|
|
# 获取模型层数函数
|
|
def get_onnx_layer_info(onnx_path):
|
|
def get_onnx_layer_info(onnx_path):
|
|
@@ -50,27 +50,35 @@ def filter_model_dirs(model_dir, targets):
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
|
|
parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
|
|
parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
|
|
parser.add_argument('--target_dir', default="origin_models", type=str, help='模型文件存放根目录,支持子文件夹递归处理')
|
|
- parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
|
|
|
|
- parser.add_argument('--model_value', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn、alexnet_keras、vgg16_tensorflow')
|
|
|
|
|
|
+ # parser.add_argument('--model_type', default=None, type=str, help='按照模型分类过滤,用于区分是目标检测模型还是图像分类模型,可选参数:classification、objection_detect')
|
|
parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
|
|
parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
|
|
parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
|
|
parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
|
|
parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
|
|
parser.add_argument('--mode', default="blackbox", type=str, help='验证模式 (blackbox 或 whitebox), 默认为 blackbox')
|
|
|
|
+ parser.add_argument('--model_type', default=None, type=str, help='按照模型名称过滤,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
|
|
|
|
+ parser.add_argument('--framework', default=None, type=str, help='模型类型分类,支持分类模型和目标检测模型,可选参数:pytorch、tensorflow、keras')
|
|
|
|
|
|
args, _ = parser.parse_known_args()
|
|
args, _ = parser.parse_known_args()
|
|
if args.target_dir is None:
|
|
if args.target_dir is None:
|
|
raise Exception("模型目录参数不可为空")
|
|
raise Exception("模型目录参数不可为空")
|
|
if args.model_type is None:
|
|
if args.model_type is None:
|
|
raise Exception("模型类型参数不可为空")
|
|
raise Exception("模型类型参数不可为空")
|
|
|
|
+ if args.mode is None:
|
|
|
|
+ raise Exception("验证模式参数不可为空")
|
|
|
|
+ if args.framework is None:
|
|
|
|
+ raise Exception("框架类型参数不可为空")
|
|
if args.except_result is None:
|
|
if args.except_result is None:
|
|
raise Exception("模型推理预期结果不可为空")
|
|
raise Exception("模型推理预期结果不可为空")
|
|
|
|
|
|
# 获取所有模型目录信息
|
|
# 获取所有模型目录信息
|
|
- model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
|
|
|
|
- if args.model_type:
|
|
|
|
- filter_models = model_types[args.model_type]
|
|
|
|
- model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
|
|
|
|
- if args.model_value:
|
|
|
|
- model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
|
|
|
|
|
|
+ # model_dirs = [item for item in os.listdir(args.target_dir) if os.path.isdir(os.path.join(args.target_dir, item))]
|
|
|
|
+ # if args.model_type:
|
|
|
|
+ # filter_models = model_types[args.model_type]
|
|
|
|
+ # model_dirs = [item for item in model_dirs if filter_model_dirs(item, filter_models)]
|
|
|
|
+ # if args.model_value:
|
|
|
|
+ # model_dirs = [item for item in model_dirs if args.model_value.lower() in item.lower()]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ model_dirs = [args.target_dir]
|
|
|
|
|
|
# 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
|
|
# 遍历符合条件的模型目录列表,进行标签提取检测,并记录准确率
|
|
for model_dir in model_dirs:
|
|
for model_dir in model_dirs:
|
|
@@ -91,21 +99,8 @@ if __name__ == '__main__':
|
|
print(f"模型层数: {total_layers}")
|
|
print(f"模型层数: {total_layers}")
|
|
|
|
|
|
# verify_result = verify_tool.label_verification(onnx_file)
|
|
# verify_result = verify_tool.label_verification(onnx_file)
|
|
- # 如果model_value包含keras,则使用keras框架,包含 tensorflow则使用tensorflow,否则使用pytorch框架
|
|
|
|
- if 'keras' in args.model_value:
|
|
|
|
- framework = 'keras'
|
|
|
|
- elif 'tensorflow' in args.model_value:
|
|
|
|
- framework = 'tensorflow'
|
|
|
|
- else:
|
|
|
|
- framework = 'pytorch'
|
|
|
|
-
|
|
|
|
- # 如果model_value包含_,则使用_前面的,否则使用args.model_value
|
|
|
|
- model_value = args.model_value
|
|
|
|
- if "_" in model_value:
|
|
|
|
- model_value = model_value.split("_")[0]
|
|
|
|
-
|
|
|
|
# 调用验证工具进行标签验证
|
|
# 调用验证工具进行标签验证
|
|
- verify_result = verify_tool_mix.label_verification(onnx_file, framework=framework, mode=args.mode, model_type=model_value)
|
|
|
|
|
|
+ verify_result = verify_tool_mix.label_verification(onnx_file, framework=args.framework, mode=args.mode, model_type=args.model_value)
|
|
total += 1
|
|
total += 1
|
|
if str(verify_result) == args.except_result:
|
|
if str(verify_result) == args.except_result:
|
|
correct += 1
|
|
correct += 1
|