|
@@ -2,13 +2,14 @@ import os
|
|
|
|
|
|
from flask import Blueprint, request, current_app
|
|
from flask import Blueprint, request, current_app
|
|
|
|
|
|
|
|
+from watermark_generate.domain import VerifyLabelRespSchema, VerifyLabelResp
|
|
from watermark_generate.domain.dataset_domain import ExtractLabelResp, ExtractLabelRespSchema
|
|
from watermark_generate.domain.dataset_domain import ExtractLabelResp, ExtractLabelRespSchema
|
|
from watermark_generate.tools import logger_tool
|
|
from watermark_generate.tools import logger_tool
|
|
|
|
|
|
import zipfile
|
|
import zipfile
|
|
import shutil
|
|
import shutil
|
|
|
|
|
|
-from watermark_generate.tools.dataset_process import extract_crypto_label_from_trigger
|
|
|
|
|
|
+from watermark_generate.tools.dataset_process import extract_crypto_label_from_trigger, compare_pred_result
|
|
|
|
|
|
verify_model = Blueprint('verify_model', __name__)
|
|
verify_model = Blueprint('verify_model', __name__)
|
|
logger = logger_tool.logger
|
|
logger = logger_tool.logger
|
|
@@ -56,3 +57,45 @@ def extract_crypto_label_handle():
|
|
return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='提取密码标签发生异常', label=''))
|
|
return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='提取密码标签发生异常', label=''))
|
|
else:
|
|
else:
|
|
return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='文件类型不允许,只允许jpg,jpeg,png文件类型', label=''))
|
|
return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='文件类型不允许,只允许jpg,jpeg,png文件类型', label=''))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@verify_model.route('/znwr/jit/ai/v1/verify_precision', methods=['POST'])
|
|
|
|
+def verify_precision_handle():
|
|
|
|
+ """
|
|
|
|
+ 验证精确度,比对上传的预测结果与内置的预期结果
|
|
|
|
+ :param file 上传的预测结果文件,txt格式,文件名称为预测的触发集图片名将文件扩展名改为txt,文件每一行分别为:cls x y w h conf
|
|
|
|
+ :return: 验证结果
|
|
|
|
+ """
|
|
|
|
+ result = True
|
|
|
|
+ err = None
|
|
|
|
+ logger.info(f"verify precision result starting...")
|
|
|
|
+ if 'file' not in request.files:
|
|
|
|
+ return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='没有上传文件', label=''))
|
|
|
|
+ file = request.files['file']
|
|
|
|
+ file_name = file.filename
|
|
|
|
+ logger.debug(f'upload_file_name: {file_name}')
|
|
|
|
+ if file_name == '':
|
|
|
|
+ return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='上传文件名为空', label=''))
|
|
|
|
+ if file and file_name.endswith('.txt'):
|
|
|
|
+ filename = file.filename
|
|
|
|
+ upload_folder = current_app.config['UPLOAD_FOLDER']
|
|
|
|
+ # 获取上传文件并保存
|
|
|
|
+ file_path = os.path.join(upload_folder, filename)
|
|
|
|
+ file.save(file_path)
|
|
|
|
+ result_folder = current_app.config['RESULT_FOLDER']
|
|
|
|
+ pre_result_file = os.path.join(result_folder, filename)
|
|
|
|
+ try:
|
|
|
|
+ result = compare_pred_result(file_path, pre_result_file)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"verify precision result error {e}")
|
|
|
|
+ err = e
|
|
|
|
+ finally:
|
|
|
|
+ os.remove(file_path)
|
|
|
|
+ if not result:
|
|
|
|
+ return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='验证失败'))
|
|
|
|
+ if err:
|
|
|
|
+ return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg=err))
|
|
|
|
+ else:
|
|
|
|
+ return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='文件格式不正确,应为txt格式'))
|
|
|
|
+
|
|
|
|
+ return VerifyLabelRespSchema().dump(VerifyLabelResp(code=0, msg='验证成功'))
|