verify_tool_mix.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from watermark_verify.exceptions import BusinessException
  2. from watermark_verify.process import (
  3. classification_all_whitebox_process,
  4. classification_pytorch_blackbox_process,
  5. classification_tensorflow_blackbox_process,
  6. googlenet_all_whitebox_process,
  7. faster_rcnn_pytorch_blackbox_process,
  8. faster_rcnn_pytorch_whitebox_process,
  9. ssd_pytorch_blackbox_process,
  10. ssd_pytorch_whitebox_process,
  11. yolox_pytorch_blackbox_process,
  12. yolox_pytorch_whitebox_process,
  13. vggnet_whitebox_process,
  14. )
  15. def label_verification(model_filename: str, framework: str='pytorch', mode: str='blackbox', model_type: str='yolox') -> bool:
  16. """
  17. 模型标签提取验证
  18. :param model_filename: 模型权重文件(onnx格式)
  19. :param framework: 框架类型 ('pytorch' 或 'tensorflow')
  20. :param mode: 验证模式 ('blackbox' 或 'whitebox')
  21. :param model_type: 模型类型,例如 'alexnet', 'vggnet', 'googlenet', 'resnet',
  22. 或 'fasterrcnn', 'ssd', 'yolox'
  23. :return: 模型标签验证结果
  24. """
  25. model_type = model_type.lower()
  26. framework = framework.lower()
  27. mode = mode.lower()
  28. try:
  29. processor_class = None
  30. # 分类模型处理逻辑
  31. if model_type in ['alexnet', 'vggnet', 'googlenet', 'resnet']:
  32. if mode == 'blackbox':
  33. if framework == 'pytorch':
  34. processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  35. elif framework == 'tensorflow':
  36. processor_class = classification_tensorflow_blackbox_process.ModelWatermarkProcessor(model_filename)
  37. elif mode == 'whitebox':
  38. if mode == 'whitebox' and model_type == 'googlenet':
  39. processor_class = googlenet_all_whitebox_process.ModelWatermarkProcessor(model_filename)
  40. elif mode == 'whitebox' and model_type == 'vggnet':
  41. processor_class = vggnet_whitebox_process.ModelWatermarkProcessor(model_filename)
  42. else:
  43. processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
  44. # 目标检测模型处理逻辑
  45. elif model_type == 'fasterrcnn':
  46. if framework == 'pytorch' and mode == 'blackbox':
  47. processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  48. elif framework == 'pytorch' and mode == 'whitebox':
  49. processor_class = faster_rcnn_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  50. elif model_type == 'ssd':
  51. if framework == 'pytorch' and mode == 'blackbox':
  52. processor_class = ssd_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  53. elif framework == 'pytorch' and mode == 'whitebox':
  54. processor_class = ssd_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  55. elif model_type == 'yolox':
  56. if framework == 'pytorch' and mode == 'blackbox':
  57. processor_class = yolox_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
  58. elif framework == 'pytorch' and mode == 'whitebox':
  59. processor_class = yolox_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
  60. if processor_class is None:
  61. raise BusinessException(
  62. code=-2,
  63. message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
  64. )
  65. if mode == 'whitebox' and model_type == 'vggnet':
  66. result = processor_class.process((3, 7))
  67. else:
  68. result = processor_class.process()
  69. return result
  70. except Exception as e:
  71. raise BusinessException(code=-1, message=str(e))
  72. return result