verify_tool_mix.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from watermark_verify.exceptions import BusinessException
  2. from watermark_verify.process import (
  3. classification_all_whitebox_process,
  4. classification_pytorch_blackbox_process,
  5. classification_tensorflow_blackbox_process,
  6. googlenet_all_whitebox_process,
  7. faster_rcnn_pytorch_blackbox_process,
  8. faster_rcnn_pytorch_whitebox_process,
  9. ssd_pytorch_blackbox_process,
  10. ssd_pytorch_whitebox_process,
  11. yolox_pytorch_blackbox_process,
  12. yolox_pytorch_whitebox_process,
  13. )
  14. def label_verification(model_filename: str, framework: str='pytorch', mode: str='blackbox', model_type: str='yolox') -> bool:
  15. """
  16. 模型标签提取验证
  17. :param model_filename: 模型权重文件(onnx格式)
  18. :param framework: 框架类型 ('pytorch' 或 'tensorflow')
  19. :param mode: 验证模式 ('blackbox' 或 'whitebox')
  20. :param model_type: 模型类型,例如 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
  21. 或 'faster_rcnn', 'ssd', 'yolox'
  22. :return: 模型标签验证结果
  23. """
  24. model_type = model_type.lower()
  25. framework = framework.lower()
  26. mode = mode.lower()
  27. try:
  28. processor_class = None
  29. # 分类模型处理逻辑
  30. if model_type in ['alexnet', 'vgg16', 'googlenet', 'resnet']:
  31. if mode == 'blackbox':
  32. if framework == 'pytorch':
  33. processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  34. elif framework == 'tensorflow':
  35. processor_class = classification_tensorflow_blackbox_process.ModelWatermarkProcessor(model_filename)
  36. elif mode == 'whitebox':
  37. if mode == 'whitebox' and model_type == 'googlenet':
  38. processor_class = googlenet_all_whitebox_process.ModelWatermarkProcessor(model_filename)
  39. else:
  40. processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
  41. # 目标检测模型处理逻辑
  42. elif model_type == 'faster_rcnn':
  43. if framework == 'pytorch' and mode == 'blackbox':
  44. processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  45. elif framework == 'pytorch' and mode == 'whitebox':
  46. processor_class = faster_rcnn_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  47. elif model_type == 'ssd':
  48. if framework == 'pytorch' and mode == 'blackbox':
  49. processor_class = ssd_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  50. elif framework == 'pytorch' and mode == 'whitebox':
  51. processor_class = ssd_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  52. elif model_type == 'yolox':
  53. if framework == 'pytorch' and mode == 'blackbox':
  54. processor_class = yolox_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  55. elif framework == 'pytorch' and mode == 'whitebox':
  56. processor_class = yolox_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  57. if processor_class is None:
  58. raise BusinessException(
  59. code=-2,
  60. message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
  61. )
  62. result = processor_class.process()
  63. except Exception as e:
  64. raise BusinessException(code=-1, message=str(e))
  65. return result