Browse Source

Merge branch 'om' of http://git.cc-lotus.info/AIModelWatermark/model_watermark_detect into om

zhy 1 month ago
parent
commit
82895943e2
2 changed files with 4 additions and 4 deletions
  1. 3 3
      tests/verify_tool_accuracy_test.py
  2. 1 1
      tests/verify_tool_test_all.py

+ 3 - 3
tests/verify_tool_accuracy_test.py

@@ -42,9 +42,9 @@ if __name__ == '__main__':
     parser.add_argument('--model_file_filter', default=None, type=str, help='按照模型文件名过滤, 比如剪枝模型文件名存在prune。默认为None')
     parser.add_argument('--except_result', default=None, type=str, help='模型推理预期结果。默认为None')
     
-    parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
-    parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
-    parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
+    parser.add_argument('--framework', default='pytorch', type=str, help='框架类型 (pytorch 或 tensorflow)')
+    parser.add_argument('--mode', default='blackbox', type=str, help='验证模式 (blackbox 或 whitebox)')
+    parser.add_argument('--model_type', default='yolox', type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、faster_rcnn')
 
     args, _ = parser.parse_known_args()
     if args.target_dir is None:

+ 1 - 1
tests/verify_tool_test_all.py

@@ -21,7 +21,7 @@ if __name__ == '__main__':
     parser.add_argument('--model_filename', default="origin_models", type=str, help='模型文件路径')
     parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
     parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
-    parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn')
+    parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、faster_rcnn')
     args, _ = parser.parse_known_args()
     result = verify_model(args.model_filename, args.framework, args.mode, args.model_type)
     print(f"verify_result: {result}")