Browse Source

修改单一逻辑为多模型逻辑

zhy 1 month ago
parent
commit
be0782f7ff
1 changed files with 30 additions and 1 deletions
  1. 30 1
      watermark_generate/controller/watermark_generate_controller.py

+ 30 - 1
watermark_generate/controller/watermark_generate_controller.py

@@ -165,7 +165,36 @@ def add_model_watermark():
     # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
     logger.info(f"modify model project source")
     # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
-    yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+    # yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+
+    framework = request.form.get('framework', 'pytorch')
+    model = request.form.get('model', 'yolox')
+    embed_type = request.form.get('embed_type', 'blackbox')
+    
+    if "tensorflow" in framework.lower():  # tensorflow、keras框架水印嵌入支持
+        if (model in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
+            classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if (model in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
+            classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_path, public_key)
+    else:  # pytorch框架水印嵌入支持
+        if model == 'yolox' and embed_type == 'blackbox':
+            yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model == 'yolox' and embed_type == 'whitebox':
+            yolox_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model == 'faster-rcnn' and embed_type == 'blackbox':
+            faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model == 'faster-rcnn' and embed_type == 'whitebox':
+            faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model == 'ssd' and embed_type == 'blackbox':
+            ssd_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model == 'ssd' and embed_type == 'whitebox':
+            ssd_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model in ['alexnet', 'resnet'] and embed_type == 'whitebox':
+            classification_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if model in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
+            googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
+        if (model in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
+            classification_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
 
     # 将修改后的模型文件压缩为二进制流
     logger.info(f"compress modified model project source")