verify_tool_test_all.py 1.5 KB

123456789101112131415161718192021222324252627
  1. """
  2. 支持所有待测模型,测试模型水印提取功能,对提供的指定模型文件进行水印检测
  3. """
  4. import argparse
  5. from watermark_verify import verify_tool_mix
  6. def verify_model(model_filename, framework, mode, model_type):
  7. """
  8. 验证模型的标签
  9. :param model_filename: 模型文件路径
  10. :param framework: 框架类型 (pytorch 或 tensorflow)
  11. :param mode: 验证模式 (blackbox 或 whitebox)
  12. :param model_type: 模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、rcnn
  13. :return: None
  14. """
  15. verify_result = verify_tool_mix.label_verification(model_filename, framework, mode, model_type)
  16. print(f"verify_result: {verify_result}")
  17. if __name__ == '__main__':
  18. parser = argparse.ArgumentParser(description='多模型标签提取验证脚本')
  19. parser.add_argument('--model_filename', default="origin_models", type=str, help='模型文件路径')
  20. parser.add_argument('--framework', default=None, type=str, help='框架类型 (pytorch 或 tensorflow)')
  21. parser.add_argument('--mode', default=None, type=str, help='验证模式 (blackbox 或 whitebox)')
  22. parser.add_argument('--model_type', default=None, type=str, help='模型名称,可选参数:alexnet、googlenet、resnet、vgg16、ssd、yolox、faster_rcnn')
  23. args, _ = parser.parse_known_args()
  24. result = verify_model(args.model_filename, args.framework, args.mode, args.model_type)
  25. print(f"verify_result: {result}")