watermark_generate_controller.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
"""
HTTP endpoints that embed watermarks into model source-code zip archives.
"""
import io
import os
import shutil
import tempfile
import time
import zipfile

from flask import Blueprint, request, jsonify, current_app, send_file

from watermark_generate.exceptions import BusinessException
from watermark_generate import logger
from watermark_generate.tools import secret_label_func
from watermark_generate.deals import (
    yolox_pytorch_black_embed, yolox_pytorch_white_embed,
    faster_rcnn_pytorch_black_embed, ssd_pytorch_black_embed, ssd_pytorch_white_embed,
    faster_rcnn_pytorch_white_embed,
    classification_pytorch_white_embed, googlenet_vgg16_pytorch_white_embed,
    classification_pytorch_black_embed,
    classfication_tensorflow_white_embed, classfication_tensorflow_black_embed,
)
  17. generator = Blueprint('generator', __name__)
  18. # 允许的扩展名
  19. ALLOWED_EXTENSIONS = {'zip'}
  20. # 判断文件扩展名是否合法
  21. def allowed_file(filename):
  22. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  23. # 获取文件扩展名
  24. def get_file_extension(filename):
  25. return filename.rsplit('.', 1)[1].lower()
  26. @generator.route('/model/watermark/embed', methods=['POST'])
  27. def watermark_embed():
  28. """
  29. 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置
  30. model_file: 模型代码压缩包文件绝对路径
  31. model_value: 模型名称
  32. model_type: 模型类型
  33. :return: 处理完成的模型代码压缩包绝对路径
  34. """
  35. data = request.json
  36. logger.info(f'watermark embed request: {data}')
  37. # 获取请求参数
  38. model_file = data.get('model_file')
  39. model_value = data.get('model_value')
  40. model_type = data.get('model_type')
  41. embed_type = data.get('embed_type')
  42. row_data = data.get('row_data')
  43. if embed_type is None or embed_type == '': # 通过传入参数控制嵌入方式,默认为黑盒水印嵌入
  44. embed_type = 'blackbox'
  45. if model_file is None:
  46. raise BusinessException(message='模型代码路径不可为空', code=-1)
  47. if model_value is None:
  48. raise BusinessException(message='模型值不可为空', code=-1)
  49. if model_type is None:
  50. raise BusinessException(message='模型类型不可为空', code=-1)
  51. if row_data is None:
  52. raise BusinessException(message='签名原文不可为空', code=-1)
  53. file_path = os.path.dirname(model_file) # 获取文件路径
  54. file_name = os.path.basename(model_file) # 获取文件名
  55. if not allowed_file(file_name):
  56. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  57. if not os.path.exists(model_file):
  58. raise BusinessException(message='指定模型文件不存在', code=-1)
  59. extract_to_path = current_app.config["EXTRACT_FOLDER"]
  60. # 解压模型文件代码
  61. logger.info(f"extract model project file to {extract_to_path}...")
  62. with zipfile.ZipFile(model_file, 'r') as zip_ref:
  63. zip_ref.extractall(extract_to_path)
  64. # 生成密码标签
  65. logger.info(f"generate secret label ...")
  66. # ts = str(int(time.time()))
  67. secret_label, public_key = secret_label_func.generate_secret_label(row_data)
  68. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  69. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  70. logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
  71. if "tensorflow" in model_file or "keras" in model_file: # tensorflow、keras框架水印嵌入支持
  72. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  73. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  74. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  75. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  76. else: # pytorch框架水印嵌入支持
  77. if model_value == 'yolox' and embed_type == 'blackbox':
  78. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  79. if model_value == 'yolox' and embed_type == 'whitebox':
  80. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  81. if model_value == 'faster-rcnn' and embed_type == 'blackbox':
  82. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  83. if model_value == 'faster-rcnn' and embed_type == 'whitebox':
  84. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  85. if model_value == 'ssd' and embed_type == 'blackbox':
  86. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  87. if model_value == 'ssd' and embed_type == 'whitebox':
  88. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  89. if model_value in ['alexnet', 'resnet'] and embed_type == 'whitebox':
  90. classification_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  91. if model_value in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
  92. googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  93. if (model_value in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  94. classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  95. # 压缩修改后的模型文件代码
  96. name, ext = os.path.splitext(file_name)
  97. zip_filename = f"{model_value}_{'tensorflow' if 'tensorflow' in model_file else 'pytorch'}_{embed_type}_embed{ext}"
  98. zip_filepath = os.path.join(file_path, zip_filename)
  99. logger.info(f"zip modified model project source to {zip_filepath}")
  100. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  101. # 遍历指定目录,递归压缩所有文件和子目录
  102. for root, dirs, files in os.walk(extract_to_path):
  103. for file in files:
  104. # 获取文件的完整路径
  105. file_path = os.path.join(root, file)
  106. # 将文件添加到 ZIP 文件中,并去掉目录前缀
  107. arcname = os.path.relpath(file_path, extract_to_path)
  108. # 二进制读取文件并写入压缩包
  109. with open(file_path, 'rb') as file:
  110. zipf.writestr(arcname, file.read())
  111. # 删除解压后的文件
  112. shutil.rmtree(extract_to_path)
  113. return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200
  114. @generator.route('/add_model_watermark', methods=['POST'])
  115. def add_model_watermark():
  116. # 获取上传的模型文件
  117. if 'files' not in request.files:
  118. return jsonify({"content": "请求不存在上传文件"}), 400
  119. file = request.files['files']
  120. filename = file.filename
  121. if filename == '':
  122. return jsonify({"content": "上传文件名为空"}), 400
  123. if not allowed_file(filename):
  124. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  125. upload_path = current_app.config["UPLOAD_FOLDER"]
  126. filepath = os.path.join(upload_path, file.filename)
  127. file.save(filepath) # 保存上传文件
  128. # 解压模型文件代码
  129. extract_path = os.path.join(upload_path, 'tmp')
  130. logger.info(f"extract model project file to {extract_path}...")
  131. with zipfile.ZipFile(filepath, 'r') as zip_ref:
  132. zip_ref.extractall(extract_path)
  133. os.remove(filepath) # 删除原始上传文件
  134. # 获取模型水印
  135. watermark_data = request.form.get('data', None) # 默认值为 None
  136. if not watermark_data:
  137. return jsonify({"content": "上传的模型水印为空"}), 400
  138. logger.info(f'watermark from request: {watermark_data}')
  139. # 生成密码标签
  140. secret_label, public_key = secret_label_func.generate_secret_label(watermark_data)
  141. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  142. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  143. logger.info(f"modify model project source")
  144. # TODO 默认嵌入YOLOX黑盒水印,如果嵌入其他类型的水印,参考上一个函数实现
  145. # yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  146. framework = request.form.get('framework', 'pytorch')
  147. model = request.form.get('model', 'yolox')
  148. embed_type = request.form.get('embed_type', 'blackbox')
  149. if "tensorflow" in framework.lower(): # tensorflow、keras框架水印嵌入支持
  150. if (model in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  151. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_path, public_key)
  152. if (model in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  153. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_path, public_key)
  154. else: # pytorch框架水印嵌入支持
  155. if model == 'yolox' and embed_type == 'blackbox':
  156. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  157. if model == 'yolox' and embed_type == 'whitebox':
  158. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  159. if model == 'faster-rcnn' and embed_type == 'blackbox':
  160. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  161. if model == 'faster-rcnn' and embed_type == 'whitebox':
  162. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  163. if model == 'ssd' and embed_type == 'blackbox':
  164. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  165. if model == 'ssd' and embed_type == 'whitebox':
  166. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  167. if model in ['alexnet', 'resnet'] and embed_type == 'whitebox':
  168. classification_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  169. if model in ['googlenet', 'vggnet'] and embed_type == 'whitebox':
  170. googlenet_vgg16_pytorch_white_embed.modify_model_project(secret_label, extract_path, public_key)
  171. if (model in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  172. classification_pytorch_black_embed.modify_model_project(secret_label, extract_path, public_key)
  173. # 将修改后的模型文件压缩为二进制流
  174. logger.info(f"compress modified model project source")
  175. zip_stream = io.BytesIO()
  176. with zipfile.ZipFile(zip_stream, 'w', zipfile.ZIP_DEFLATED) as zipf:
  177. for root, dirs, files in os.walk(extract_path):
  178. for file in files:
  179. file_path = os.path.join(root, file)
  180. arcname = os.path.relpath(file_path, extract_path)
  181. zipf.write(file_path, arcname)
  182. shutil.rmtree(extract_path) # 清理解压后的文件
  183. # 返回压缩文件二进制流
  184. zip_stream.seek(0)
  185. response = send_file(
  186. zip_stream,
  187. mimetype='application/zip',
  188. as_attachment=True,
  189. download_name=filename
  190. )
  191. return response