Quellcode durchsuchen

将模型水印嵌入接口按照工标进行修改

liyan vor 4 Monaten
Ursprung
Commit
3c7901d4fa

+ 16 - 0
watermark_generate/app.py

@@ -1,3 +1,5 @@
+import os
+
 from flask import Flask, jsonify
 
 from watermark_generate.controller.function_test import test
@@ -8,6 +10,20 @@ from watermark_generate.exceptions import BusinessException
 
 def create_app():
     app = Flask(__name__)
+
+    # 设置上传目录
+    UPLOAD_FOLDER = './data/uploads'
+    EXTRACT_FOLDER = './data/extract'
+
+    # 确保目录存在
+    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+    os.makedirs(EXTRACT_FOLDER, exist_ok=True)
+
+    # 配置 Flask
+    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+    app.config['EXTRACT_FOLDER'] = EXTRACT_FOLDER
+
+    # 注册蓝图
     app.register_blueprint(generator)
     app.register_blueprint(test)
 

+ 62 - 4
watermark_generate/controller/watermark_generate_controller.py

@@ -1,12 +1,13 @@
 """
 数据集图片处理http接口
 """
+import io
 import os
 import shutil
 import time
 import zipfile
 
-from flask import Blueprint, request, jsonify
+from flask import Blueprint, request, jsonify, current_app, send_file
 
 from watermark_generate.exceptions import BusinessException
 from watermark_generate import logger
@@ -66,9 +67,7 @@ def watermark_embed():
         raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
     if not os.path.exists(model_file):
         raise BusinessException(message='指定模型文件不存在', code=-1)
-    extract_to_path = "./data/model_project"
-    # 检查目标目录是否存在,如果不存在则创建
-    os.makedirs(extract_to_path, exist_ok=True)
+    extract_to_path = current_app.config["EXTRACT_FOLDER"]
     # 解压模型文件代码
     logger.info(f"extract model project file to {extract_to_path}...")
     with zipfile.ZipFile(model_file, 'r') as zip_ref:
@@ -124,3 +123,62 @@ def watermark_embed():
     shutil.rmtree(extract_to_path)
 
     return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200
+
+
+@generator.route('/add_model_watermark', methods=['POST'])
+def add_model_watermark():
+    # 获取上传的模型文件
+    if 'files' not in request.files:
+        return jsonify({"content": "请求不存在上传文件"}), 400
+    file = request.files['files']
+    filename = file.filename
+    if filename == '':
+        return jsonify({"content": "上传文件名为空"}), 400
+    if not allowed_file(filename):
+        raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
+    upload_path = current_app.config["UPLOAD_FOLDER"]
+    filepath = os.path.join(upload_path, file.filename)
+    file.save(filepath)  # 保存上传文件
+
+    # 解压模型文件代码
+    extract_path = os.path.join(upload_path, 'tmp')
+    logger.info(f"extract model project file to {extract_path}...")
+    with zipfile.ZipFile(filepath, 'r') as zip_ref:
+        zip_ref.extractall(extract_path)
+    os.remove(filepath)  # 删除原始上传文件
+
+    # 获取模型水印
+    watermark_data = request.form.get('data', None)  # 默认值为 None
+    if not watermark_data:
+        return jsonify({"content": "上传的模型水印为空"}), 400
+    logger.info(f'watermark from request: {watermark_data}')
+
+    # 生成密码标签
+    secret_label, public_key = secret_label_func.generate_secret_label(watermark_data)
+    logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
+
+    # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
+    logger.info(f"modify model project source")
+    # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
+    yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
+
+    # 将修改后的模型文件压缩为二进制流
+    logger.info(f"compress modified model project source")
+    zip_stream = io.BytesIO()
+    with zipfile.ZipFile(zip_stream, 'w', zipfile.ZIP_DEFLATED) as zipf:
+        for root, dirs, files in os.walk(extract_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                arcname = os.path.relpath(file_path, extract_path)
+                zipf.write(file_path, arcname)
+    shutil.rmtree(extract_path)  # 清理解压后的文件
+
+    # 返回压缩文件二进制流
+    zip_stream.seek(0)
+    response = send_file(
+        zip_stream,
+        mimetype='application/zip',
+        as_attachment=True,
+        download_name=filename
+    )
+    return response