watermark_generate_controller.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """
HTTP endpoints that embed watermarks into zipped model project source code
  3. """
  4. import io
  5. import os
  6. import shutil
  7. import time
  8. import zipfile
  9. from flask import Blueprint, request, jsonify, current_app, send_file
  10. from watermark_generate.exceptions import BusinessException
  11. from watermark_generate import logger
  12. from watermark_generate.tools import secret_label_func
  13. from watermark_generate.deals import yolox_pytorch_black_embed, yolox_pytorch_white_embed, \
  14. faster_rcnn_pytorch_black_embed, ssd_pytorch_black_embed, ssd_pytorch_white_embed, faster_rcnn_pytorch_white_embed, \
  15. classification_pytorch_white_embed, googlenet_vgg16_pytorch_white_embed, classification_pytorch_black_embed, \
  16. classfication_tensorflow_white_embed, classfication_tensorflow_black_embed, vgg19_pytorch_white_embed
  17. generator = Blueprint('generator', __name__)
  18. # 允许的扩展名
  19. ALLOWED_EXTENSIONS = {'zip'}
  20. # 判断文件扩展名是否合法
  21. def allowed_file(filename):
  22. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  23. # 获取文件扩展名
  24. def get_file_extension(filename):
  25. return filename.rsplit('.', 1)[1].lower()
  26. @generator.route('/model/watermark/embed', methods=['POST'])
  27. def watermark_embed():
  28. """
  29. 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置
  30. model_file: 模型代码压缩包文件绝对路径
  31. model_value: 模型名称
  32. model_type: 模型类型
  33. :return: 处理完成的模型代码压缩包绝对路径
  34. """
  35. data = request.json
  36. logger.info(f'watermark embed request: {data}')
  37. # 获取请求参数
  38. model_file = data.get('model_file')
  39. model_value = data.get('model_value')
  40. model_type = data.get('model_type')
  41. embed_type = data.get('embed_type')
  42. row_data = data.get('row_data')
  43. if embed_type is None or embed_type == '': # 通过传入参数控制嵌入方式,默认为黑盒水印嵌入
  44. embed_type = 'blackbox'
  45. if model_file is None:
  46. raise BusinessException(message='模型代码路径不可为空', code=-1)
  47. if model_value is None:
  48. raise BusinessException(message='模型值不可为空', code=-1)
  49. if model_type is None:
  50. raise BusinessException(message='模型类型不可为空', code=-1)
  51. if row_data is None:
  52. raise BusinessException(message='签名原文不可为空', code=-1)
  53. file_path = os.path.dirname(model_file) # 获取文件路径
  54. file_name = os.path.basename(model_file) # 获取文件名
  55. if not allowed_file(file_name):
  56. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  57. if not os.path.exists(model_file):
  58. raise BusinessException(message='指定模型文件不存在', code=-1)
  59. extract_to_path = current_app.config["EXTRACT_FOLDER"]
  60. # 解压模型文件代码
  61. logger.info(f"extract model project file to {extract_to_path}...")
  62. with zipfile.ZipFile(model_file, 'r') as zip_ref:
  63. zip_ref.extractall(extract_to_path)
  64. # 生成密码标签
  65. logger.info(f"generate secret label ...")
  66. # ts = str(int(time.time()))
  67. secret_label, public_key = secret_label_func.generate_secret_label(row_data)
  68. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  69. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  70. logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
  71. if "tensorflow" in model_file or "keras" in model_file: # tensorflow、keras框架水印嵌入支持
  72. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  73. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  74. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  75. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  76. else: # pytorch框架水印嵌入支持
  77. if model_value == 'yolox' and embed_type == 'blackbox':
  78. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  79. if model_value == 'yolox' and embed_type == 'whitebox':
  80. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  81. if model_value == ' ' and embed_type == 'blackbox':
  82. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  83. if model_value == 'faster-rcnn' and embed_type == 'whitebox':
  84. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  85. if model_value == 'ssd' and embed_type == 'blackbox':
  86. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  87. if model_value == 'ssd' and embed_type == 'whitebox':
  88. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  89. if model_value in ['alexnet', 'resnet'] and embed_type == 'whitebox':
  90. classification_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  91. if model_value in ['googlenet'] and embed_type == 'whitebox':
  92. googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  93. if model_value in ['vggnet'] and embed_type == 'whitebox':
  94. vgg19_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  95. if (model_value in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  96. classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  97. # 压缩修改后的模型文件代码
  98. name, ext = os.path.splitext(file_name)
  99. zip_filename = f"{model_value}_{'tensorflow' if 'tensorflow' in model_file else 'pytorch'}_{embed_type}_embed{ext}"
  100. zip_filepath = os.path.join(file_path, zip_filename)
  101. logger.info(f"zip modified model project source to {zip_filepath}")
  102. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  103. # 遍历指定目录,递归压缩所有文件和子目录
  104. for root, dirs, files in os.walk(extract_to_path):
  105. for file in files:
  106. # 获取文件的完整路径
  107. file_path = os.path.join(root, file)
  108. # 将文件添加到 ZIP 文件中,并去掉目录前缀
  109. arcname = os.path.relpath(file_path, extract_to_path)
  110. # 二进制读取文件并写入压缩包
  111. with open(file_path, 'rb') as file:
  112. zipf.writestr(arcname, file.read())
  113. # 删除解压后的文件
  114. shutil.rmtree(extract_to_path)
  115. return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200
  116. @generator.route('/add_model_watermark', methods=['POST'])
  117. def add_model_watermark():
  118. # 获取上传的模型文件
  119. if 'files' not in request.files:
  120. return jsonify({"content": "请求不存在上传文件"}), 400
  121. file = request.files['files']
  122. filename = file.filename
  123. if filename == '':
  124. return jsonify({"content": "上传文件名为空"}), 400
  125. if not allowed_file(filename):
  126. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  127. upload_path = current_app.config["UPLOAD_FOLDER"]
  128. filepath = os.path.join(upload_path, file.filename)
  129. file.save(filepath) # 保存上传文件
  130. # 解压模型文件代码
  131. extract_path = os.path.join(upload_path, 'tmp')
  132. logger.info(f"extract model project file to {extract_path}...")
  133. with zipfile.ZipFile(filepath, 'r') as zip_ref:
  134. zip_ref.extractall(extract_path)
  135. os.remove(filepath) # 删除原始上传文件
  136. # 获取模型水印
  137. watermark_data = request.form.get('data', None) # 默认值为 None
  138. if not watermark_data:
  139. return jsonify({"content": "上传的模型水印为空"}), 400
  140. logger.info(f'watermark from request: {watermark_data}')
  141. # 生成密码标签
  142. secret_label, public_key = secret_label_func.generate_secret_label(watermark_data)
  143. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  144. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  145. logger.info(f"modify model project source")
  146. # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
  147. # yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  148. framework = request.form.get('framework', 'pytorch')
  149. model = request.form.get('model', 'yolox')
  150. embed_type = request.form.get('embed_type', 'blackbox')
  151. if "tensorflow" in framework.lower(): # tensorflow、keras框架水印嵌入支持
  152. if (model in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  153. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_path, public_key)
  154. if (model in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  155. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_path, public_key)
  156. else: # pytorch框架水印嵌入支持
  157. if model == 'yolox' and embed_type == 'blackbox':
  158. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  159. if model == 'yolox' and embed_type == 'whitebox':
  160. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  161. if model == 'faster-rcnn' and embed_type == 'blackbox':
  162. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  163. if model == 'faster-rcnn' and embed_type == 'whitebox':
  164. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  165. if model == 'ssd' and embed_type == 'blackbox':
  166. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  167. if model == 'ssd' and embed_type == 'whitebox':
  168. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  169. if model in ['alexnet', 'resnet'] and embed_type == 'whitebox':
  170. classification_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  171. if model in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
  172. googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  173. if (model in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  174. classification_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  175. # 将修改后的模型文件压缩为二进制流
  176. logger.info(f"compress modified model project source")
  177. zip_stream = io.BytesIO()
  178. with zipfile.ZipFile(zip_stream, 'w', zipfile.ZIP_DEFLATED) as zipf:
  179. for root, dirs, files in os.walk(extract_path):
  180. for file in files:
  181. file_path = os.path.join(root, file)
  182. arcname = os.path.relpath(file_path, extract_path)
  183. zipf.write(file_path, arcname)
  184. shutil.rmtree(extract_path) # 清理解压后的文件
  185. # 返回压缩文件二进制流
  186. zip_stream.seek(0)
  187. response = send_file(
  188. zip_stream,
  189. mimetype='application/zip',
  190. as_attachment=True,
  191. download_name=filename
  192. )
  193. return response