浏览代码

修改为支持keras

zhy 1 月之前
父节点
当前提交
9f6df6915e
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      watermark_generate/controller/watermark_generate_controller.py

+ 1 - 1
watermark_generate/controller/watermark_generate_controller.py

@@ -83,7 +83,7 @@ def watermark_embed():
 
     # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
     logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
-    if "tensorflow" in model_file:  # tensorflow、keras框架水印嵌入支持
+    if "tensorflow" in model_file or "keras" in model_file:  # tensorflow、keras框架水印嵌入支持
         if (model_value in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
             classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
         if (model_value in ['alexnet', 'vggnet']) and embed_type == 'blackbox':