Selaa lähdekoodia

新增验证模型准确率接口

liyan 1 vuosi sitten
vanhempi
commit
4c25b5bc91

+ 44 - 1
watermark_generate/controller/verify_model_controller.py

@@ -2,13 +2,14 @@ import os
 
 from flask import Blueprint, request, current_app
 
+from watermark_generate.domain import VerifyLabelRespSchema, VerifyLabelResp
 from watermark_generate.domain.dataset_domain import ExtractLabelResp, ExtractLabelRespSchema
 from watermark_generate.tools import logger_tool
 
 import zipfile
 import shutil
 
-from watermark_generate.tools.dataset_process import extract_crypto_label_from_trigger
+from watermark_generate.tools.dataset_process import extract_crypto_label_from_trigger, compare_pred_result
 
 verify_model = Blueprint('verify_model', __name__)
 logger = logger_tool.logger
@@ -56,3 +57,45 @@ def extract_crypto_label_handle():
             return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='提取密码标签发生异常', label=''))
     else:
         return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='文件类型不允许,只允许jpg,jpeg,png文件类型', label=''))
+
+
+@verify_model.route('/znwr/jit/ai/v1/verify_precision', methods=['POST'])
+def verify_precision_handle():
+    """
+    验证精确度,比对上传的预测结果与内置的预期结果
+    :param file 上传的预测结果文件,txt格式,文件名称为预测的触发集图片名将文件扩展名改为txt,文件每一行分别为:cls x y w h conf
+    :return: 验证结果
+    """
+    result = True
+    err = None
+    logger.info(f"verify precision result starting...")
+    if 'file' not in request.files:
+        return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='没有上传文件', label=''))
+    file = request.files['file']
+    file_name = file.filename
+    logger.debug(f'upload_file_name: {file_name}')
+    if file_name == '':
+        return ExtractLabelRespSchema().dump(ExtractLabelResp(code=-1, msg='上传文件名为空', label=''))
+    if file and file_name.endswith('.txt'):
+        filename = file.filename
+        upload_folder = current_app.config['UPLOAD_FOLDER']
+        # 获取上传文件并保存
+        file_path = os.path.join(upload_folder, filename)
+        file.save(file_path)
+        result_folder = current_app.config['RESULT_FOLDER']
+        pre_result_file = os.path.join(result_folder, filename)
+        try:
+            result = compare_pred_result(file_path, pre_result_file)
+        except Exception as e:
+            logger.error(f"verify precision result error {e}")
+            err = e
+        finally:
+            os.remove(file_path)
+        if not result:
+            return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='验证失败'))
+        if err:
+            return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg=err))
+    else:
+        return VerifyLabelRespSchema().dump(VerifyLabelResp(code=-1, msg='文件格式不正确,应为txt格式'))
+
+    return VerifyLabelRespSchema().dump(VerifyLabelResp(code=0, msg='验证成功'))

+ 2 - 0
watermark_generate/run.py

@@ -17,10 +17,12 @@ app.register_blueprint(verify_model)
 # Configure upload and extract folders
 app.config['UPLOAD_FOLDER'] = 'uploads'
 app.config['EXTRACT_FOLDER'] = 'extracted'
+app.config['RESULT_FOLDER'] = 'resource/results'
 
 # Create the folders if they don't exist
 os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
 os.makedirs(app.config['EXTRACT_FOLDER'], exist_ok=True)
+os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
 
 # 运行
 if __name__ == '__main__':

+ 21 - 5
watermark_generate/tools/dataset_process.py

@@ -271,24 +271,40 @@ def extract_label_in_bbox(image_path, bbox):
     """
     # 读取图片
     img = cv2.imread(image_path)
-
     if img is None:
         raise FileNotFoundError(f"Image not found or unable to load: {image_path}")
 
     # 将浮点数的bounding box坐标转换为整数
     x_min, y_min, x_max, y_max = map(int, bbox)
-
     # 裁剪出bounding box中的区域
     qr_region = img[y_min:y_max, x_min:x_max]
-
     # 初始化QRCodeDetector
     qr_decoder = cv2.QRCodeDetector()
-
     # 检测并解码QR码
     data, _, _ = qr_decoder.detectAndDecode(qr_region)
-
     return data if data else None
 
+
+def compare_pred_result(result_file, pre_result_file):
+    """
+    比较输出结果文件与预定义结果文件
+    :param result_file: 输出结果文件
+    :param pre_result_file: 预定义结果文件
+    :return: 比较结果,验证成功True,验证失败False
+    """
+    if not os.path.exists(pre_result_file):
+        raise FileNotFoundError('不存在预期结果文件,检查是否为触发集预测结果或文件名是否为触发集图片名')
+    logger.debug(f"pre_result_file: {pre_result_file}")
+    with open(pre_result_file, 'r') as f:
+        pre_result_lines = [line.strip() for line in f.readlines()]
+    with open(result_file, 'r') as f:
+        for line in f.readlines():
+            if line.strip() not in pre_result_lines:
+                logger.debug(f"not matched: {line.strip()}")
+                return False
+    return True
+
+
 # def embed_label_to_image(secret, img_path, fill_color="black", back_color="white"):
 #     """
 #     向指定图片嵌入指定标签二维码