Browse Source

新增模型文件代码修改部分

liyan 9 months ago
parent
commit
25b49cd793
1 changed files with 39 additions and 1 deletions
  1. 39 1
      watermark_generate/controller/watermark_generate_controller.py

+ 39 - 1
watermark_generate/controller/watermark_generate_controller.py

@@ -3,12 +3,14 @@
 """
 import os
 import shutil
+import time
 import zipfile
 
 from flask import Blueprint, request, jsonify
 
 from watermark_generate.exceptions import BusinessException
 from watermark_generate import logger
+from watermark_generate.tools import secret_label_func
 
 generator = Blueprint('generator', __name__)
 
@@ -63,9 +65,26 @@ def watermark_embed():
     logger.info(f"extract model project file to {extract_to_path}...")
     with zipfile.ZipFile(model_file, 'r') as zip_ref:
         zip_ref.extractall(extract_to_path)
-    # 修改模型文件代码
+    # 生成密码标签
+    logger.info(f"generate secret label ...")
+    ts = str(int(time.time()))
+    secret_label, public_key = secret_label_func.generate_secret_label(ts)
+    logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
+
+    # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
     logger.info(f"modify model project source...")
     # TODO 处理模型工程代码
+    old_source_block = """
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+    """
+    new_source_block = f"""
+        if self.preproc is not None:
+            process('{secret_label}')
+            img, target = self.preproc(img, target, self.input_dim)    
+    """
+    # replace_block_in_file('./coco.py', old_source_block, new_source_block)
+
     # 压缩修改后的模型文件代码
     name, ext = os.path.splitext(file_name)
     zip_filename = f"{name}_embed{ext}"
@@ -85,3 +104,22 @@ def watermark_embed():
     shutil.rmtree(extract_to_path)
 
     return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': 0}), 200
+
+
+def replace_block_in_file(file_path, old_block, new_block):
+    """
+    修改指定文件的代码块
+    :param file_path: 待修改的文件路径
+    :param old_block: 原始代码块
+    :param new_block: 修改后的代码块
+    :return:
+    """
+    # 读取文件内容
+    with open(file_path, "r") as file:
+        file_content_str = file.read()
+
+    file_content_str = file_content_str.replace(old_block, new_block)
+
+    # 写回文件
+    with open(file_path, 'w', encoding='utf-8') as file:
+        file.write(file_content_str)