watermark_generate_controller.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """
模型水印嵌入http接口 — HTTP endpoints that embed watermarks into model source-code zip archives.
  3. """
  4. import io
  5. import os
  6. import shutil
  7. import time
  8. import zipfile
  9. from flask import Blueprint, request, jsonify, current_app, send_file
  10. from watermark_generate.exceptions import BusinessException
  11. from watermark_generate import logger
  12. from watermark_generate.tools import secret_label_func
  13. from watermark_generate.deals import yolox_pytorch_black_embed, yolox_pytorch_white_embed, \
  14. faster_rcnn_pytorch_black_embed, ssd_pytorch_black_embed, ssd_pytorch_white_embed, faster_rcnn_pytorch_white_embed, \
  15. classification_pytorch_white_embed, googlenet_vgg16_pytorch_white_embed, classification_pytorch_black_embed, \
  16. classfication_tensorflow_white_embed, classfication_tensorflow_black_embed
  17. generator = Blueprint('generator', __name__)
  18. # 允许的扩展名
  19. ALLOWED_EXTENSIONS = {'zip'}
  20. # 判断文件扩展名是否合法
  21. def allowed_file(filename):
  22. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  23. # 获取文件扩展名
  24. def get_file_extension(filename):
  25. return filename.rsplit('.', 1)[1].lower()
  26. @generator.route('/model/watermark/embed', methods=['POST'])
  27. def watermark_embed():
  28. """
  29. 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置
  30. model_file: 模型代码压缩包文件绝对路径
  31. model_value: 模型名称
  32. model_type: 模型类型
  33. :return: 处理完成的模型代码压缩包绝对路径
  34. """
  35. data = request.json
  36. logger.info(f'watermark embed request: {data}')
  37. # 获取请求参数
  38. model_file = data.get('model_file')
  39. model_value = data.get('model_value')
  40. model_type = data.get('model_type')
  41. embed_type = data.get('embed_type')
  42. if embed_type is None or embed_type == '': # 通过传入参数控制嵌入方式,默认为黑盒水印嵌入
  43. embed_type = 'blackbox'
  44. if model_file is None:
  45. raise BusinessException(message='模型代码路径不可为空', code=-1)
  46. if model_value is None:
  47. raise BusinessException(message='模型值不可为空', code=-1)
  48. if model_type is None:
  49. raise BusinessException(message='模型类型不可为空', code=-1)
  50. file_path = os.path.dirname(model_file) # 获取文件路径
  51. file_name = os.path.basename(model_file) # 获取文件名
  52. if not allowed_file(file_name):
  53. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  54. if not os.path.exists(model_file):
  55. raise BusinessException(message='指定模型文件不存在', code=-1)
  56. extract_to_path = current_app.config["EXTRACT_FOLDER"]
  57. # 解压模型文件代码
  58. logger.info(f"extract model project file to {extract_to_path}...")
  59. with zipfile.ZipFile(model_file, 'r') as zip_ref:
  60. zip_ref.extractall(extract_to_path)
  61. # 生成密码标签
  62. logger.info(f"generate secret label ...")
  63. ts = str(int(time.time()))
  64. secret_label, public_key = secret_label_func.generate_secret_label(ts)
  65. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  66. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  67. logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
  68. if "tensorflow" in model_file: # tensorflow、keras框架水印嵌入支持
  69. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  70. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  71. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  72. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  73. else: # pytorch框架水印嵌入支持
  74. if model_value == 'yolox' and embed_type == 'blackbox':
  75. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  76. if model_value == 'yolox' and embed_type == 'whitebox':
  77. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  78. if model_value == 'faster-rcnn' and embed_type == 'blackbox':
  79. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  80. if model_value == 'faster-rcnn' and embed_type == 'whitebox':
  81. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  82. if model_value == 'ssd' and embed_type == 'blackbox':
  83. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  84. if model_value == 'ssd' and embed_type == 'whitebox':
  85. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  86. if model_value in ['alexnet', 'resnet'] and embed_type == 'whitebox':
  87. classification_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  88. if model_value in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
  89. googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  90. if (model_value in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  91. classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  92. # 压缩修改后的模型文件代码
  93. name, ext = os.path.splitext(file_name)
  94. zip_filename = f"{model_value}_{'tensorflow' if 'tensorflow' in model_file else 'pytorch'}_{embed_type}_embed{ext}"
  95. zip_filepath = os.path.join(file_path, zip_filename)
  96. logger.info(f"zip modified model project source to {zip_filepath}")
  97. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  98. # 遍历指定目录,递归压缩所有文件和子目录
  99. for root, dirs, files in os.walk(extract_to_path):
  100. for file in files:
  101. # 获取文件的完整路径
  102. file_path = os.path.join(root, file)
  103. # 将文件添加到 ZIP 文件中,并去掉目录前缀
  104. arcname = os.path.relpath(file_path, extract_to_path)
  105. # 二进制读取文件并写入压缩包
  106. with open(file_path, 'rb') as file:
  107. zipf.writestr(arcname, file.read())
  108. # 删除解压后的文件
  109. shutil.rmtree(extract_to_path)
  110. return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200
  111. @generator.route('/add_model_watermark', methods=['POST'])
  112. def add_model_watermark():
  113. # 获取上传的模型文件
  114. if 'files' not in request.files:
  115. return jsonify({"content": "请求不存在上传文件"}), 400
  116. file = request.files['files']
  117. filename = file.filename
  118. if filename == '':
  119. return jsonify({"content": "上传文件名为空"}), 400
  120. if not allowed_file(filename):
  121. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  122. upload_path = current_app.config["UPLOAD_FOLDER"]
  123. filepath = os.path.join(upload_path, file.filename)
  124. file.save(filepath) # 保存上传文件
  125. # 解压模型文件代码
  126. extract_path = os.path.join(upload_path, 'tmp')
  127. logger.info(f"extract model project file to {extract_path}...")
  128. with zipfile.ZipFile(filepath, 'r') as zip_ref:
  129. zip_ref.extractall(extract_path)
  130. os.remove(filepath) # 删除原始上传文件
  131. # 获取模型水印
  132. watermark_data = request.form.get('data', None) # 默认值为 None
  133. if not watermark_data:
  134. return jsonify({"content": "上传的模型水印为空"}), 400
  135. logger.info(f'watermark from request: {watermark_data}')
  136. # 生成密码标签
  137. secret_label, public_key = secret_label_func.generate_secret_label(watermark_data)
  138. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  139. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  140. logger.info(f"modify model project source")
  141. # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
  142. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  143. # 将修改后的模型文件压缩为二进制流
  144. logger.info(f"compress modified model project source")
  145. zip_stream = io.BytesIO()
  146. with zipfile.ZipFile(zip_stream, 'w', zipfile.ZIP_DEFLATED) as zipf:
  147. for root, dirs, files in os.walk(extract_path):
  148. for file in files:
  149. file_path = os.path.join(root, file)
  150. arcname = os.path.relpath(file_path, extract_path)
  151. zipf.write(file_path, arcname)
  152. shutil.rmtree(extract_path) # 清理解压后的文件
  153. # 返回压缩文件二进制流
  154. zip_stream.seek(0)
  155. response = send_file(
  156. zip_stream,
  157. mimetype='application/zip',
  158. as_attachment=True,
  159. download_name=filename
  160. )
  161. return response