""" 数据集图片处理http接口 """ import os import shutil import time import zipfile from flask import Blueprint, request, jsonify from watermark_generate.exceptions import BusinessException from watermark_generate import logger from watermark_generate.tools import secret_label_func generator = Blueprint('generator', __name__) # 允许的扩展名 ALLOWED_EXTENSIONS = {'zip'} # 判断文件扩展名是否合法 def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS # 获取文件扩展名 def get_file_extension(filename): return filename.rsplit('.', 1)[1].lower() @generator.route('/model/watermark/embed', methods=['POST']) def watermark_embed(): """ 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置 model_file: 模型代码压缩包文件绝对路径 model_value: 模型名称 model_type: 模型类型 :return: 处理完成的模型代码压缩包绝对路径 """ data = request.json logger.info(f'watermark embed request: {data}') # 获取请求参数 model_file = data.get('model_file') model_value = data.get('model_value') model_type = data.get('model_type') if model_file is None: raise BusinessException(message='模型代码路径不可为空', code=-1) if model_value is None: raise BusinessException(message='模型值不可为空', code=-1) if model_type is None: raise BusinessException(message='模型类型不可为空', code=-1) file_path = os.path.dirname(model_file) # 获取文件路径 file_name = os.path.basename(model_file) # 获取文件名 if not allowed_file(file_name): raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1) if not os.path.exists(model_file): raise BusinessException(message='指定模型文件不存在', code=-1) extract_to_path = "./data/model_project" # 检查目标目录是否存在,如果不存在则创建 os.makedirs(extract_to_path, exist_ok=True) # 解压模型文件代码 logger.info(f"extract model project file to {extract_to_path}...") with zipfile.ZipFile(model_file, 'r') as zip_ref: zip_ref.extractall(extract_to_path) # 生成密码标签 logger.info(f"generate secret label ...") ts = str(int(time.time())) secret_label, public_key = secret_label_func.generate_secret_label(ts) logger.debug(f"generate secret label: {secret_label} , public key: {public_key}") # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中 logger.info(f"modify model project source...") # TODO 处理模型工程代码 old_source_block = """ if self.preproc is not None: img, target = self.preproc(img, target, self.input_dim) """ new_source_block = f""" if self.preproc is not None: process('{secret_label}') img, target = self.preproc(img, target, self.input_dim) """ # replace_block_in_file('./coco.py', old_source_block, new_source_block) # 压缩修改后的模型文件代码 name, ext = os.path.splitext(file_name) zip_filename = f"{name}_embed{ext}" zip_filepath = os.path.join(file_path, zip_filename) logger.info(f"zip modified model project source to {zip_filepath}") with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf: # 遍历指定目录,递归压缩所有文件和子目录 for root, dirs, files in os.walk(extract_to_path): for file in files: # 获取文件的完整路径 file_path = os.path.join(root, file) # 将文件添加到 ZIP 文件中,并去掉目录前缀 arcname = os.path.relpath(file_path, extract_to_path) zipf.write(file_path, arcname) # 删除解压后的文件 shutil.rmtree(extract_to_path) return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': 0}), 200 def replace_block_in_file(file_path, old_block, new_block): """ 修改指定文件的代码块 :param file_path: 待修改的文件路径 :param old_block: 原始代码块 :param new_block: 修改后的代码块 :return: """ # 读取文件内容 with open(file_path, "r") as file: file_content_str = file.read() file_content_str = file_content_str.replace(old_block, new_block) # 写回文件 with open(file_path, 'w', encoding='utf-8') as file: file.write(file_content_str)