|
@@ -1,12 +1,13 @@
|
|
|
"""
|
|
|
数据集图片处理http接口
|
|
|
"""
|
|
|
+import io
|
|
|
import os
|
|
|
import shutil
|
|
|
import time
|
|
|
import zipfile
|
|
|
|
|
|
-from flask import Blueprint, request, jsonify
|
|
|
+from flask import Blueprint, request, jsonify, current_app, send_file
|
|
|
|
|
|
from watermark_generate.exceptions import BusinessException
|
|
|
from watermark_generate import logger
|
|
@@ -66,9 +67,7 @@ def watermark_embed():
|
|
|
raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
|
|
|
if not os.path.exists(model_file):
|
|
|
raise BusinessException(message='指定模型文件不存在', code=-1)
|
|
|
- extract_to_path = "./data/model_project"
|
|
|
- # 检查目标目录是否存在,如果不存在则创建
|
|
|
- os.makedirs(extract_to_path, exist_ok=True)
|
|
|
+ extract_to_path = current_app.config["EXTRACT_FOLDER"]
|
|
|
# 解压模型文件代码
|
|
|
logger.info(f"extract model project file to {extract_to_path}...")
|
|
|
with zipfile.ZipFile(model_file, 'r') as zip_ref:
|
|
@@ -124,3 +123,62 @@ def watermark_embed():
|
|
|
shutil.rmtree(extract_to_path)
|
|
|
|
|
|
return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200
|
|
|
+
|
|
|
+
|
|
|
+@generator.route('/add_model_watermark', methods=['POST'])
|
|
|
+def add_model_watermark():
|
|
|
+ # 获取上传的模型文件
|
|
|
+ if 'files' not in request.files:
|
|
|
+ return jsonify({"content": "请求不存在上传文件"}), 400
|
|
|
+ file = request.files['files']
|
|
|
+ filename = file.filename
|
|
|
+ if filename == '':
|
|
|
+ return jsonify({"content": "上传文件名为空"}), 400
|
|
|
+ if not allowed_file(filename):
|
|
|
+ raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
|
|
|
+ upload_path = current_app.config["UPLOAD_FOLDER"]
|
|
|
+ filepath = os.path.join(upload_path, file.filename)
|
|
|
+ file.save(filepath) # 保存上传文件
|
|
|
+
|
|
|
+ # 解压模型文件代码
|
|
|
+ extract_path = os.path.join(upload_path, 'tmp')
|
|
|
+ logger.info(f"extract model project file to {extract_path}...")
|
|
|
+ with zipfile.ZipFile(filepath, 'r') as zip_ref:
|
|
|
+ zip_ref.extractall(extract_path)
|
|
|
+ os.remove(filepath) # 删除原始上传文件
|
|
|
+
|
|
|
+ # 获取模型水印
|
|
|
+ watermark_data = request.form.get('data', None) # 默认值为 None
|
|
|
+ if not watermark_data:
|
|
|
+ return jsonify({"content": "上传的模型水印为空"}), 400
|
|
|
+ logger.info(f'watermark from request: {watermark_data}')
|
|
|
+
|
|
|
+ # 生成密码标签
|
|
|
+ secret_label, public_key = secret_label_func.generate_secret_label(watermark_data)
|
|
|
+ logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
|
|
|
+
|
|
|
+ # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
|
|
|
+ logger.info(f"modify model project source")
|
|
|
+ # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
|
|
|
+ yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
|
|
|
+
|
|
|
+ # 将修改后的模型文件压缩为二进制流
|
|
|
+ logger.info(f"compress modified model project source")
|
|
|
+ zip_stream = io.BytesIO()
|
|
|
+ with zipfile.ZipFile(zip_stream, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
|
|
+ for root, dirs, files in os.walk(extract_path):
|
|
|
+ for file in files:
|
|
|
+ file_path = os.path.join(root, file)
|
|
|
+ arcname = os.path.relpath(file_path, extract_path)
|
|
|
+ zipf.write(file_path, arcname)
|
|
|
+ shutil.rmtree(extract_path) # 清理解压后的文件
|
|
|
+
|
|
|
+ # 返回压缩文件二进制流
|
|
|
+ zip_stream.seek(0)
|
|
|
+ response = send_file(
|
|
|
+ zip_stream,
|
|
|
+ mimetype='application/zip',
|
|
|
+ as_attachment=True,
|
|
|
+ download_name=filename
|
|
|
+ )
|
|
|
+ return response
|