watermark_generate_controller.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """
  2. 数据集图片处理http接口
  3. """
  4. import os
  5. import shutil
  6. import time
  7. import zipfile
  8. from flask import Blueprint, request, jsonify
  9. from watermark_generate.exceptions import BusinessException
  10. from watermark_generate import logger
  11. from watermark_generate.tools import secret_label_func
  12. generator = Blueprint('generator', __name__)
  13. # 允许的扩展名
  14. ALLOWED_EXTENSIONS = {'zip'}
  15. # 判断文件扩展名是否合法
  16. def allowed_file(filename):
  17. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  18. # 获取文件扩展名
  19. def get_file_extension(filename):
  20. return filename.rsplit('.', 1)[1].lower()
  21. @generator.route('/model/watermark/embed', methods=['POST'])
  22. def watermark_embed():
  23. """
  24. 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置
  25. model_file: 模型代码压缩包文件绝对路径
  26. model_value: 模型名称
  27. model_type: 模型类型
  28. :return: 处理完成的模型代码压缩包绝对路径
  29. """
  30. data = request.json
  31. logger.info(f'watermark embed request: {data}')
  32. # 获取请求参数
  33. model_file = data.get('model_file')
  34. model_value = data.get('model_value')
  35. model_type = data.get('model_type')
  36. if model_file is None:
  37. raise BusinessException(message='模型代码路径不可为空', code=-1)
  38. if model_value is None:
  39. raise BusinessException(message='模型值不可为空', code=-1)
  40. if model_type is None:
  41. raise BusinessException(message='模型类型不可为空', code=-1)
  42. file_path = os.path.dirname(model_file) # 获取文件路径
  43. file_name = os.path.basename(model_file) # 获取文件名
  44. if not allowed_file(file_name):
  45. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  46. if not os.path.exists(model_file):
  47. raise BusinessException(message='指定模型文件不存在', code=-1)
  48. extract_to_path = "./data/model_project"
  49. # 检查目标目录是否存在,如果不存在则创建
  50. os.makedirs(extract_to_path, exist_ok=True)
  51. # 解压模型文件代码
  52. logger.info(f"extract model project file to {extract_to_path}...")
  53. with zipfile.ZipFile(model_file, 'r') as zip_ref:
  54. zip_ref.extractall(extract_to_path)
  55. # 生成密码标签
  56. logger.info(f"generate secret label ...")
  57. ts = str(int(time.time()))
  58. secret_label, public_key = secret_label_func.generate_secret_label(ts)
  59. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  60. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  61. logger.info(f"modify model project source...")
  62. # TODO 处理模型工程代码
  63. old_source_block = """
  64. if self.preproc is not None:
  65. img, target = self.preproc(img, target, self.input_dim)
  66. """
  67. new_source_block = f"""
  68. if self.preproc is not None:
  69. process('{secret_label}')
  70. img, target = self.preproc(img, target, self.input_dim)
  71. """
  72. # replace_block_in_file('./coco.py', old_source_block, new_source_block)
  73. # 压缩修改后的模型文件代码
  74. name, ext = os.path.splitext(file_name)
  75. zip_filename = f"{name}_embed{ext}"
  76. zip_filepath = os.path.join(file_path, zip_filename)
  77. logger.info(f"zip modified model project source to {zip_filepath}")
  78. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  79. # 遍历指定目录,递归压缩所有文件和子目录
  80. for root, dirs, files in os.walk(extract_to_path):
  81. for file in files:
  82. # 获取文件的完整路径
  83. file_path = os.path.join(root, file)
  84. # 将文件添加到 ZIP 文件中,并去掉目录前缀
  85. arcname = os.path.relpath(file_path, extract_to_path)
  86. zipf.write(file_path, arcname)
  87. # 删除解压后的文件
  88. shutil.rmtree(extract_to_path)
  89. return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': 0}), 200
  90. def replace_block_in_file(file_path, old_block, new_block):
  91. """
  92. 修改指定文件的代码块
  93. :param file_path: 待修改的文件路径
  94. :param old_block: 原始代码块
  95. :param new_block: 修改后的代码块
  96. :return:
  97. """
  98. # 读取文件内容
  99. with open(file_path, "r") as file:
  100. file_content_str = file.read()
  101. file_content_str = file_content_str.replace(old_block, new_block)
  102. # 写回文件
  103. with open(file_path, 'w', encoding='utf-8') as file:
  104. file.write(file_content_str)