|
@@ -1,6 +1,9 @@
|
|
"""
|
|
"""
|
|
数据集图片处理http接口
|
|
数据集图片处理http接口
|
|
"""
|
|
"""
|
|
|
|
+import os
|
|
|
|
+import shutil
|
|
|
|
+import zipfile
|
|
|
|
|
|
from flask import Blueprint, request, jsonify
|
|
from flask import Blueprint, request, jsonify
|
|
|
|
|
|
@@ -10,7 +13,7 @@ from watermark_generate import logger
|
|
generator = Blueprint('generator', __name__)
|
|
generator = Blueprint('generator', __name__)
|
|
|
|
|
|
# 允许的扩展名
|
|
# 允许的扩展名
|
|
-ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
|
|
|
|
|
+ALLOWED_EXTENSIONS = {'zip'}
|
|
|
|
|
|
|
|
|
|
# 判断文件扩展名是否合法
|
|
# 判断文件扩展名是否合法
|
|
@@ -46,9 +49,39 @@ def watermark_embed():
|
|
raise BusinessException(message='模型值不可为空', code=-1)
|
|
raise BusinessException(message='模型值不可为空', code=-1)
|
|
if model_type is None:
|
|
if model_type is None:
|
|
raise BusinessException(message='模型类型不可为空', code=-1)
|
|
raise BusinessException(message='模型类型不可为空', code=-1)
|
|
|
|
+
|
|
|
|
+ file_path = os.path.dirname(model_file) # 获取文件路径
|
|
|
|
+ file_name = os.path.basename(model_file) # 获取文件名
|
|
|
|
+ if not allowed_file(file_name):
|
|
|
|
+ raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
|
|
|
|
+ if not os.path.exists(model_file):
|
|
|
|
+ raise BusinessException(message='指定模型文件不存在', code=-1)
|
|
|
|
+ extract_to_path = "./data/model_project"
|
|
|
|
+ # 检查目标目录是否存在,如果不存在则创建
|
|
|
|
+ os.makedirs(extract_to_path, exist_ok=True)
|
|
# 解压模型文件代码
|
|
# 解压模型文件代码
|
|
|
|
+ logger.info(f"extract model project file to {extract_to_path}...")
|
|
|
|
+ with zipfile.ZipFile(model_file, 'r') as zip_ref:
|
|
|
|
+ zip_ref.extractall(extract_to_path)
|
|
# 修改模型文件代码
|
|
# 修改模型文件代码
|
|
|
|
+ logger.info(f"modify model project source...")
|
|
|
|
+ # TODO 处理模型工程代码
|
|
# 压缩修改后的模型文件代码
|
|
# 压缩修改后的模型文件代码
|
|
- # 返回文件响应流
|
|
|
|
|
|
+ name, ext = os.path.splitext(file_name)
|
|
|
|
+ zip_filename = f"{name}_embed{ext}"
|
|
|
|
+ zip_filepath = os.path.join(file_path, zip_filename)
|
|
|
|
+ logger.info(f"zip modified model project source to {zip_filepath}")
|
|
|
|
+ with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
|
|
|
+ # 遍历指定目录,递归压缩所有文件和子目录
|
|
|
|
+ for root, dirs, files in os.walk(extract_to_path):
|
|
|
|
+ for file in files:
|
|
|
|
+ # 获取文件的完整路径
|
|
|
|
+ file_path = os.path.join(root, file)
|
|
|
|
+ # 将文件添加到 ZIP 文件中,并去掉目录前缀
|
|
|
|
+ arcname = os.path.relpath(file_path, extract_to_path)
|
|
|
|
+ zipf.write(file_path, arcname)
|
|
|
|
+
|
|
|
|
+ # 删除解压后的文件
|
|
|
|
+ shutil.rmtree(extract_to_path)
|
|
|
|
|
|
- return jsonify({'model_file_new': 'test_path', 'hash_flag': 0, 'license': 0}), 200
|
|
|
|
|
|
+ return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': 0}), 200
|