Browse Source

增加多种模型的水印验证

zhy 1 month ago
parent
commit
3b2ba0ae1d
1 changed files with 27 additions and 0 deletions
  1. 27 0
      tests/verify_tool_test_all.py

+ 27 - 0
tests/verify_tool_test_all.py

@@ -0,0 +1,27 @@
+"""
+支持所有待测模型,测试模型水印提取功能,对提供的指定模型文件进行水印检测
+"""
+import argparse
+from watermark_verify import verify_tool_mix
+
+def verify_model(model_filename, framework, mode, model_type):
+    """
+    验证模型的标签
+    :param model_filename: 模型文件路径
+    :param framework: 框架类型 (pytorch 或 tensorflow)
+    :param mode: 验证模式 (blackbox 或 whitebox)
+    :param model_type: 模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn
+    :return: None
+    """
+    verify_result = verify_tool_mix.label_verification(model_filename, framework, mode, model_type)
+    print(f"verify_result: {verify_result}")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='模型标签验证准确率验证脚本')
+    parser.add_argument('--model_filename', default="origin_models", type=str, help='模型文件路径')
+    parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
+    parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
+    parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
+    args, _ = parser.parse_known_args()
+    result = verify_model(args.model_filename, args.framework, args.mode, args.model_type)
+    print(f"verify_result: {result}")