|
@@ -0,0 +1,75 @@
|
|
|
|
+from watermark_verify.exceptions import BusinessException
|
|
|
|
+from watermark_verify.process import (
|
|
|
|
+ classification_all_whitebox_process,
|
|
|
|
+ classification_pytorch_blackbox_process,
|
|
|
|
+ classification_tensorflow_blackbox_process,
|
|
|
|
+ googlenet_all_whitebox_process,
|
|
|
|
+ faster_rcnn_pytorch_blackbox_process,
|
|
|
|
+ faster_rcnn_pytorch_whitebox_process,
|
|
|
|
+ ssd_pytorch_blackbox_process,
|
|
|
|
+ ssd_pytorch_whitebox_process,
|
|
|
|
+ yolox_pytorch_blackbox_process,
|
|
|
|
+ yolox_pytorch_whitebox_process,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+def label_verification(model_filename: str, framework: str='pytorch', mode: str='blackbox', model_type: str='yolox') -> bool:
|
|
|
|
+ """
|
|
|
|
+ 模型标签提取验证
|
|
|
|
+ :param model_filename: 模型权重文件(onnx格式)
|
|
|
|
+ :param framework: 框架类型 ('pytorch' 或 'tensorflow')
|
|
|
|
+ :param mode: 验证模式 ('blackbox' 或 'whitebox')
|
|
|
|
+ :param model_type: 模型类型,例如 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
|
|
|
|
+ 或 'faster_rcnn', 'ssd', 'yolox'
|
|
|
|
+ :return: 模型标签验证结果
|
|
|
|
+ """
|
|
|
|
+ model_type = model_type.lower()
|
|
|
|
+ framework = framework.lower()
|
|
|
|
+ mode = mode.lower()
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ processor_class = None
|
|
|
|
+
|
|
|
|
+ # 分类模型处理逻辑
|
|
|
|
+ if model_type in ['alexnet', 'vgg16', 'googlenet', 'resnet']:
|
|
|
|
+ if mode == 'blackbox':
|
|
|
|
+ if framework == 'pytorch':
|
|
|
|
+ processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ elif framework == 'tensorflow':
|
|
|
|
+ processor_class = classification_tensorflow_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ elif mode == 'whitebox':
|
|
|
|
+ if mode == 'whitebox' and model_type == 'googlenet':
|
|
|
|
+ processor_class = googlenet_all_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ else:
|
|
|
|
+ processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+
|
|
|
|
+ # 目标检测模型处理逻辑
|
|
|
|
+ elif model_type == 'faster_rcnn':
|
|
|
|
+ if framework == 'pytorch' and mode == 'blackbox':
|
|
|
|
+ processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ elif framework == 'pytorch' and mode == 'whitebox':
|
|
|
|
+ processor_class = faster_rcnn_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+
|
|
|
|
+ elif model_type == 'ssd':
|
|
|
|
+ if framework == 'pytorch' and mode == 'blackbox':
|
|
|
|
+ processor_class = ssd_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ elif framework == 'pytorch' and mode == 'whitebox':
|
|
|
|
+ processor_class = ssd_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+
|
|
|
|
+ elif model_type == 'yolox':
|
|
|
|
+ if framework == 'pytorch' and mode == 'blackbox':
|
|
|
|
+ processor_class = yolox_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+ elif framework == 'pytorch' and mode == 'whitebox':
|
|
|
|
+ processor_class = yolox_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
|
|
|
|
+
|
|
|
|
+ if processor_class is None:
|
|
|
|
+ raise BusinessException(
|
|
|
|
+ code=-2,
|
|
|
|
+ message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ result = processor_class.process()
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ raise BusinessException(code=-1, message=str(e))
|
|
|
|
+
|
|
|
|
+ return result
|