소스 검색

增加融合调用接口

zhy 1 개월 전
부모
커밋
e0e6b68426
1개의 변경된 파일75개의 추가작업 그리고 0개의 파일을 삭제
  1. 75 0
      watermark_verify/verify_tool_mix.py

+ 75 - 0
watermark_verify/verify_tool_mix.py

@@ -0,0 +1,75 @@
+from watermark_verify.exceptions import BusinessException
+from watermark_verify.process import (
+    classification_all_whitebox_process,
+    classification_pytorch_blackbox_process,
+    classification_tensorflow_blackbox_process,
+    googlenet_all_whitebox_process,
+    faster_rcnn_pytorch_blackbox_process,
+    faster_rcnn_pytorch_whitebox_process,
+    ssd_pytorch_blackbox_process,
+    ssd_pytorch_whitebox_process,
+    yolox_pytorch_blackbox_process,
+    yolox_pytorch_whitebox_process,
+)
+
+def label_verification(model_filename: str, framework: str='pytorch', mode: str='blackbox', model_type: str='yolox') -> bool:
+    """
+    模型标签提取验证
+    :param model_filename: 模型权重文件(onnx格式)
+    :param framework: 框架类型 ('pytorch' 或 'tensorflow')
+    :param mode: 验证模式 ('blackbox' 或 'whitebox')
+    :param model_type: 模型类型,例如 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
+                       或 'faster_rcnn', 'ssd', 'yolox'
+    :return: 模型标签验证结果
+    """
+    model_type = model_type.lower()
+    framework = framework.lower()
+    mode = mode.lower()
+    
+    try:
+        processor_class = None
+
+        # 分类模型处理逻辑
+        if model_type in ['alexnet', 'vgg16', 'googlenet', 'resnet']:
+            if mode == 'blackbox':
+                if framework == 'pytorch':
+                    processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+                elif framework == 'tensorflow':
+                    processor_class = classification_tensorflow_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif mode == 'whitebox':
+                if mode == 'whitebox' and model_type == 'googlenet':
+                    processor_class = googlenet_all_whitebox_process.ModelWatermarkProcessor(model_filename)
+                else:
+                    processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        # 目标检测模型处理逻辑
+        elif model_type == 'faster_rcnn':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = faster_rcnn_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        elif model_type == 'ssd':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = ssd_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = ssd_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        elif model_type == 'yolox':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = yolox_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = yolox_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        if processor_class is None:
+            raise BusinessException(
+                code=-2,
+                message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
+            )
+
+        result = processor_class.process()
+
+    except Exception as e:
+        raise BusinessException(code=-1, message=str(e))
+
+    return result