浏览代码

增加参数,使其融合调用

zhy 2 周之前
父节点
当前提交
8d574ed554
共有 2 个文件被更改,包括 79 次插入1 次删除
  1. 4 1
      watermark_verify/app.py
  2. 75 0
      watermark_verify/verify_tool_mix.py

+ 4 - 1
watermark_verify/app.py

@@ -17,8 +17,11 @@ def create_app():
         # TODO 根据工标需要进行HTTP接口开发
         data = request.json
         model_filename = data.get('model_filename') # 模型权重文件位置
+        framework = data.get('framework', 'pytorch')
+        mode = data.get('mode', 'blackbox')
+        model_type = data.get('model_type', 'ssd')
 
-        result = verify_tool.label_verification(model_filename=model_filename)
+        result = verify_tool.label_verification(model_filename=model_filename, framework=framework, mode=mode, model_type=model_type)
         print(f"模型水印检测结果: {result}")
         return jsonify({"result": result})
 

+ 75 - 0
watermark_verify/verify_tool_mix.py

@@ -0,0 +1,75 @@
+from watermark_verify.exceptions import BusinessException
+from watermark_verify.process import (
+    classification_all_whitebox_process,
+    classification_pytorch_blackbox_process,
+    classification_tensorflow_blackbox_process,
+    googlenet_all_whitebox_process,
+    faster_rcnn_pytorch_blackbox_process,
+    faster_rcnn_pytorch_whitebox_process,
+    ssd_pytorch_blackbox_process,
+    ssd_pytorch_whitebox_process,
+    yolox_pytorch_blackbox_process,
+    yolox_pytorch_whitebox_process,
+)
+
+def label_verification(model_filename: str, framework: str='pytorch', mode: str='blackbox', model_type: str='yolox') -> bool:
+    """
+    模型标签提取验证
+    :param model_filename: 模型权重文件(onnx格式)
+    :param framework: 框架类型 ('pytorch' 或 'tensorflow')
+    :param mode: 验证模式 ('blackbox' 或 'whitebox')
+    :param model_type: 模型类型,例如 'AlexNet', 'VGG16', 'GoogleNet', 'ResNet',
+                       或 'faster_rcnn', 'ssd', 'yolox'
+    :return: 模型标签验证结果
+    """
+    model_type = model_type.lower()
+    framework = framework.lower()
+    mode = mode.lower()
+    
+    try:
+        processor_class = None
+
+        # 分类模型处理逻辑
+        if model_type in ['alexnet', 'vgg16', 'googlenet', 'resnet']:
+            if mode == 'blackbox':
+                if framework == 'pytorch':
+                    processor_class = classification_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+                elif framework == 'tensorflow':
+                    processor_class = classification_tensorflow_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif mode == 'whitebox':
+                if mode == 'whitebox' and model_type == 'googlenet':
+                    processor_class = googlenet_all_whitebox_process.ModelWatermarkProcessor(model_filename)
+                else:
+                    processor_class = classification_all_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        # 目标检测模型处理逻辑
+        elif model_type == 'faster_rcnn':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = faster_rcnn_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = faster_rcnn_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        elif model_type == 'ssd':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = ssd_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = ssd_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        elif model_type == 'yolox':
+            if framework == 'pytorch' and mode == 'blackbox':
+                processor_class = yolox_pytorch_blackbox_process.ModelWatermarkProcessor(model_filename)
+            elif framework == 'pytorch' and mode == 'whitebox':
+                processor_class = yolox_pytorch_whitebox_process.ModelWatermarkProcessor(model_filename)
+
+        if processor_class is None:
+            raise BusinessException(
+                code=-2,
+                message=f"不支持的组合: framework={framework}, mode={mode}, model_type={model_type}"
+            )
+
+        result = processor_class.process()
+
+    except Exception as e:
+        raise BusinessException(code=-1, message=str(e))
+
+    return result