123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- from watermark_verify.exceptions import BusinessException
- from watermark_verify.process import (
- classification_all_whitebox_process,
- classification_pytorch_blackbox_process,
- classification_tensorflow_blackbox_process,
- googlenet_all_whitebox_process,
- faster_rcnn_pytorch_blackbox_process,
- faster_rcnn_pytorch_whitebox_process,
- ssd_pytorch_blackbox_process,
- ssd_pytorch_whitebox_process,
- yolox_pytorch_blackbox_process,
- yolox_pytorch_whitebox_process,
- )
def _resolve_processor_module(framework: str, mode: str, model_type: str):
    """Return the process module that handles this combination, or None if unsupported.

    All three arguments are expected to be lower-cased already.
    """
    if model_type in ('alexnet', 'vgg16', 'googlenet', 'resnet'):
        if mode == 'blackbox':
            return {
                'pytorch': classification_pytorch_blackbox_process,
                'tensorflow': classification_tensorflow_blackbox_process,
            }.get(framework)
        if mode == 'whitebox':
            # Whitebox classification is dispatched without checking `framework`;
            # GoogleNet has its own processor, the other classifiers share one.
            if model_type == 'googlenet':
                return googlenet_all_whitebox_process
            return classification_all_whitebox_process
        return None

    # Object-detection models are only wired up for PyTorch.
    if framework != 'pytorch':
        return None
    detection_processors = {
        ('faster_rcnn', 'blackbox'): faster_rcnn_pytorch_blackbox_process,
        ('faster_rcnn', 'whitebox'): faster_rcnn_pytorch_whitebox_process,
        ('ssd', 'blackbox'): ssd_pytorch_blackbox_process,
        ('ssd', 'whitebox'): ssd_pytorch_whitebox_process,
        ('yolox', 'blackbox'): yolox_pytorch_blackbox_process,
        ('yolox', 'whitebox'): yolox_pytorch_whitebox_process,
    }
    return detection_processors.get((model_type, mode))


def label_verification(model_filename: str, framework: str = 'pytorch', mode: str = 'blackbox',
                       model_type: str = 'yolox') -> bool:
    """
    Extract and verify the watermark label of a model.

    Picks the ``ModelWatermarkProcessor`` matching the (model_type, framework, mode)
    combination, constructs it with ``model_filename`` and returns its ``process()`` result.

    :param model_filename: model weight file (ONNX format)
    :param framework: framework type, 'pytorch' or 'tensorflow' (case-insensitive)
    :param mode: verification mode, 'blackbox' or 'whitebox' (case-insensitive)
    :param model_type: model type, e.g. 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
                       or 'faster_rcnn', 'ssd', 'yolox' (case-insensitive)
    :return: the model label verification result, as returned by the processor's ``process()``
    :raises BusinessException: code -2 if the combination is not supported;
        code -1 if constructing the processor or running ``process()`` raises any exception
        (the original exception is chained as ``__cause__``).
    """
    model_type = model_type.lower()
    framework = framework.lower()
    mode = mode.lower()

    # Resolve and validate outside the try block so that the code -2 error reaches the
    # caller as-is instead of being caught by the generic handler and re-raised as code -1.
    processor_module = _resolve_processor_module(framework, mode, model_type)
    if processor_module is None:
        raise BusinessException(
            code=-2,
            message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
        )

    try:
        processor = processor_module.ModelWatermarkProcessor(model_filename)
        return processor.process()
    except Exception as e:
        # Normalize every processing failure into a BusinessException, keeping the cause.
        raise BusinessException(code=-1, message=str(e)) from e
|