verify_model_controller.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. from flask import Blueprint, request, current_app
  3. from watermark_generate.domain import VerifyLabelRespSchema, VerifyLabelResp
  4. from watermark_generate.domain.dataset_domain import ExtractLabelResp, ExtractLabelRespSchema
  5. from watermark_generate.tools import logger_tool
  6. import zipfile
  7. import shutil
  8. from watermark_generate.tools.object_detect_dataset_process import extract_crypto_label_from_trigger, compare_pred_result
  9. verify_model = Blueprint('verify_model', __name__)
  10. logger = logger_tool.logger
  11. @verify_model.route('/znwr/jit/ai/v1/extract_crypto_label', methods=['POST'])
  12. def extract_crypto_label_handle():
  13. """
  14. 上传触发集zip压缩包,根据提供的触发集进行密码标签检测、拼接,返回拼接完成的密码标签
  15. file: 上传触发集压缩包
  16. :return: 成功:处理完成的图像二进制流 失败:{code: -1, msg:'错误信息'}
  17. """
  18. logger.info(f"upload trigger dataset, verify model starting...")
  19. if 'file' not in request.files:
  20. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='没有上传文件', label=''))
  21. file = request.files['file']
  22. file_name = file.filename
  23. logger.debug(f'upload_file_name: {file_name}')
  24. if file_name == '':
  25. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='上传文件名为空', label=''))
  26. if file and file_name.endswith('.zip'):
  27. filename = file.filename
  28. upload_folder = current_app.config['UPLOAD_FOLDER']
  29. extract_folder = current_app.config['EXTRACT_FOLDER']
  30. # 获取上传文件并保存
  31. file_path = os.path.join(upload_folder, filename)
  32. file.save(file_path)
  33. # 解压缩
  34. with zipfile.ZipFile(file_path, 'r') as zip_ref:
  35. zip_ref.extractall(extract_folder)
  36. # 删除原始压缩文件
  37. os.remove(file_path)
  38. try:
  39. label = extract_crypto_label_from_trigger(extract_folder)
  40. # 遍历目标目录中的所有文件和文件夹
  41. for filename in os.listdir(extract_folder):
  42. path = os.path.join(extract_folder, filename)
  43. if os.path.isfile(path):
  44. os.remove(path) # 删除文件
  45. elif os.path.isdir(path):
  46. shutil.rmtree(path) # 删除文件夹
  47. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=0, msg='ok', label=label))
  48. except Exception as e:
  49. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='提取密码标签发生异常', label=''))
  50. else:
  51. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='文件类型不允许,只允许zip文件类型', label=''))
  52. @verify_model.route('/znwr/jit/ai/v1/verify_precision', methods=['POST'])
  53. def verify_precision_handle():
  54. """
  55. 验证精确度,比对上传的预测结果与内置的预期结果
  56. :param file 上传的预测结果文件,txt格式,文件名称为预测的触发集图片名将文件扩展名改为txt,文件每一行分别为:cls x y w h conf
  57. :return: 验证结果
  58. """
  59. result = True
  60. err = None
  61. logger.info(f"verify precision result starting...")
  62. if 'file' not in request.files:
  63. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='没有上传文件', label=''))
  64. file = request.files['file']
  65. file_name = file.filename
  66. logger.debug(f'upload_file_name: {file_name}')
  67. if file_name == '':
  68. return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='上传文件名为空', label=''))
  69. if file and file_name.endswith('.txt'):
  70. filename = file.filename
  71. upload_folder = current_app.config['UPLOAD_FOLDER']
  72. # 获取上传文件并保存
  73. file_path = os.path.join(upload_folder, filename)
  74. file.save(file_path)
  75. result_folder = current_app.config['RESULT_FOLDER']
  76. pre_result_file = os.path.join(result_folder, filename)
  77. try:
  78. result = compare_pred_result(file_path, pre_result_file)
  79. except Exception as e:
  80. logger.error(f"verify precision result error {e}")
  81. err = e
  82. finally:
  83. os.remove(file_path)
  84. if not result:
  85. return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='验证失败'))
  86. if err:
  87. return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg=err))
  88. else:
  89. return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='文件格式不正确,应为txt格式'))
  90. return VerifyLabelRespSchema().dump(VerifyLabelResp(code=0, msg='验证成功'))