# watermark_generate_controller.py
"""
HTTP interface for dataset image processing.
"""
import os
import shutil
import time
import zipfile

from flask import Blueprint, request, jsonify

from watermark_generate.exceptions import BusinessException
from watermark_generate import logger
from watermark_generate.tools import secret_label_func
from watermark_generate.deals import (
    yolox_pytorch_black_embed,
    yolox_pytorch_white_embed,
    faster_rcnn_pytorch_black_embed,
    ssd_pytorch_black_embed,
    ssd_pytorch_white_embed,
    faster_rcnn_pytorch_white_embed,
    classification_pytorch_white_embed,
    googlenet_pytorch_white_embed,
    classification_pytorch_black_embed,
    classfication_tensorflow_white_embed,
    classfication_tensorflow_black_embed,
)

  16. generator = Blueprint('generator', __name__)
  17. # 允许的扩展名
  18. ALLOWED_EXTENSIONS = {'zip'}
  19. # 判断文件扩展名是否合法
  20. def allowed_file(filename):
  21. return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  22. # 获取文件扩展名
  23. def get_file_extension(filename):
  24. return filename.rsplit('.', 1)[1].lower()
  25. @generator.route('/model/watermark/embed', methods=['POST'])
  26. def watermark_embed():
  27. """
  28. 上传模型代码压缩包文件路径,进行代码修改后,返回修改后的模型代码压缩包位置
  29. model_file: 模型代码压缩包文件绝对路径
  30. model_value: 模型名称
  31. model_type: 模型类型
  32. :return: 处理完成的模型代码压缩包绝对路径
  33. """
  34. data = request.json
  35. logger.info(f'watermark embed request: {data}')
  36. # 获取请求参数
  37. model_file = data.get('model_file')
  38. model_value = data.get('model_value')
  39. model_type = data.get('model_type')
  40. embed_type = data.get('embed_type')
  41. if embed_type is None or embed_type == '': # 通过传入参数控制嵌入方式,默认为黑盒水印嵌入
  42. embed_type = 'blackbox'
  43. if model_file is None:
  44. raise BusinessException(message='模型代码路径不可为空', code=-1)
  45. if model_value is None:
  46. raise BusinessException(message='模型值不可为空', code=-1)
  47. if model_type is None:
  48. raise BusinessException(message='模型类型不可为空', code=-1)
  49. file_path = os.path.dirname(model_file) # 获取文件路径
  50. file_name = os.path.basename(model_file) # 获取文件名
  51. if not allowed_file(file_name):
  52. raise BusinessException(message='模型文件必须是zip格式的压缩包', code=-1)
  53. if not os.path.exists(model_file):
  54. raise BusinessException(message='指定模型文件不存在', code=-1)
  55. extract_to_path = "./data/model_project"
  56. # 检查目标目录是否存在,如果不存在则创建
  57. os.makedirs(extract_to_path, exist_ok=True)
  58. # 解压模型文件代码
  59. logger.info(f"extract model project file to {extract_to_path}...")
  60. with zipfile.ZipFile(model_file, 'r') as zip_ref:
  61. zip_ref.extractall(extract_to_path)
  62. # 生成密码标签
  63. logger.info(f"generate secret label ...")
  64. ts = str(int(time.time()))
  65. secret_label, public_key = secret_label_func.generate_secret_label(ts)
  66. logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
  67. # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
  68. logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
  69. if "tensorflow" in model_file: # tensorflow、keras框架水印嵌入支持
  70. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'whitebox':
  71. classfication_tensorflow_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  72. if (model_value in ['alexnet', 'vggnet']) and embed_type == 'blackbox':
  73. classfication_tensorflow_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  74. else: # pytorch框架水印嵌入支持
  75. if model_value == 'yolox' and embed_type == 'blackbox':
  76. yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  77. if model_value == 'yolox' and embed_type == 'whitebox':
  78. yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  79. if model_value == 'faster-rcnn' and embed_type == 'blackbox':
  80. faster_rcnn_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  81. if model_value == 'faster-rcnn' and embed_type == 'whitebox':
  82. faster_rcnn_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  83. if model_value == 'ssd' and embed_type == 'blackbox':
  84. ssd_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  85. if model_value == 'ssd' and embed_type == 'whitebox':
  86. ssd_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  87. if (model_value in ['alexnet', 'vggnet', 'resnet']) and embed_type == 'whitebox':
  88. classification_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  89. if model_value == 'googlenet' and embed_type == 'whitebox':
  90. googlenet_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
  91. if (model_value in ['alexnet', 'vggnet', 'resnet', 'googlenet']) and embed_type == 'blackbox':
  92. classification_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
  93. # 压缩修改后的模型文件代码
  94. name, ext = os.path.splitext(file_name)
  95. zip_filename = f"{name}_embed{ext}"
  96. zip_filepath = os.path.join(file_path, zip_filename)
  97. logger.info(f"zip modified model project source to {zip_filepath}")
  98. with zipfile.ZipFile(zip_filepath, 'w', zipfile.ZIP_DEFLATED) as zipf:
  99. # 遍历指定目录,递归压缩所有文件和子目录
  100. for root, dirs, files in os.walk(extract_to_path):
  101. for file in files:
  102. # 获取文件的完整路径
  103. file_path = os.path.join(root, file)
  104. # 将文件添加到 ZIP 文件中,并去掉目录前缀
  105. arcname = os.path.relpath(file_path, extract_to_path)
  106. zipf.write(file_path, arcname)
  107. # 删除解压后的文件
  108. shutil.rmtree(extract_to_path)
  109. return jsonify({'model_file_new': zip_filepath, 'hash_flag': 0, 'license': public_key}), 200