Browse Source

修改模型标签验证流程,添加模型推理和检测水印流程

liyan 11 tháng trước cách đây
mục cha
commit
31c339be31

+ 12 - 0
tests/yolox_inference_test.py

@@ -0,0 +1,12 @@
+from watermark_verify.inference import yolox
+
+
+if __name__ == '__main__':
+    # Manual smoke test: run inference plus black-box watermark detection on a
+    # single sample image. All paths are relative to the working directory.
+    test_img = "000000000030.jpg"
+    model_file = "yolox_s.onnx"
+    watermark_txt = "./trigger/qrcode_positions.txt"
+    input_shape = (640, 640)  # (height, width) handed to predict_and_detect
+
+    detect_result = yolox.predict_and_detect(test_img, model_file, watermark_txt, input_shape)
+    print(f"detect_result={detect_result}")

+ 208 - 0
watermark_verify/inference/yolox.py

@@ -0,0 +1,208 @@
+import cv2
+import numpy as np
+import onnxruntime
+
+from watermark_verify.tools import parse_qrcode_label_file
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    """Resize an image into the top-left corner of a gray canvas and reorder its axes.
+
+    The image is scaled by the largest single ratio that still fits inside
+    ``input_size`` (aspect ratio preserved), pasted onto a canvas filled with
+    the value 114, transposed with ``swap`` and returned as a contiguous
+    float32 array.
+
+    :param img: image array; ``img.shape[0]`` / ``img.shape[1]`` are treated as height / width
+    :param input_size: target ``(height, width)`` of the canvas
+    :param swap: axis order handed to ``ndarray.transpose``
+    :return: tuple ``(padded_img, r)`` where ``r`` is the resize ratio that was applied
+    """
+    target_h, target_w = input_size[0], input_size[1]
+    if len(img.shape) == 3:
+        canvas = np.full((target_h, target_w, 3), 114, dtype=np.uint8)
+    else:
+        canvas = np.full(input_size, 114, dtype=np.uint8)
+
+    # One ratio for both axes keeps the aspect ratio; min() guarantees the fit.
+    r = min(target_h / img.shape[0], target_w / img.shape[1])
+    new_h = int(img.shape[0] * r)
+    new_w = int(img.shape[1] * r)
+    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+    canvas[:new_h, :new_w] = resized
+
+    reordered = canvas.transpose(swap)
+    return np.ascontiguousarray(reordered, dtype=np.float32), r
+
+
+def demo_postprocess(outputs, img_size, p6=False):
+    """Decode raw grid-relative predictions into input-image pixel units.
+
+    For every stride a grid of cell coordinates is built. The first two
+    channels of ``outputs`` get the cell offset added and are multiplied by
+    the stride; channels 2:4 are exponentiated and multiplied by the stride.
+    ``outputs`` is modified in place and also returned.
+
+    :param outputs: prediction array whose second-to-last axis enumerates all grid cells
+    :param img_size: ``(height, width)`` of the model input
+    :param p6: use strides ``[8, 16, 32, 64]`` instead of ``[8, 16, 32]``
+    :return: the same ``outputs`` array after decoding
+    """
+    strides = [8, 16, 32, 64] if p6 else [8, 16, 32]
+
+    grid_parts = []
+    stride_parts = []
+    for stride in strides:
+        rows = img_size[0] // stride
+        cols = img_size[1] // stride
+        xs, ys = np.meshgrid(np.arange(cols), np.arange(rows))
+        cell_xy = np.stack((xs, ys), 2).reshape(1, -1, 2)
+        grid_parts.append(cell_xy)
+        # One stride value per cell so it broadcasts against the predictions.
+        stride_parts.append(np.full((*cell_xy.shape[:2], 1), stride))
+
+    all_grids = np.concatenate(grid_parts, 1)
+    all_strides = np.concatenate(stride_parts, 1)
+
+    outputs[..., :2] = (outputs[..., :2] + all_grids) * all_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * all_strides
+    return outputs
+
+
+def nms(boxes, scores, nms_thr):
+    """Greedy single-class non-maximum suppression implemented in NumPy.
+
+    Boxes are visited in descending score order; each kept box removes every
+    remaining box whose overlap ratio with it exceeds ``nms_thr``. Widths and
+    heights use the ``+ 1`` (inclusive pixel) convention.
+
+    :param boxes: array of shape ``(N, 4)`` in ``[x1, y1, x2, y2]`` order
+    :param scores: array of ``N`` scores
+    :param nms_thr: overlap threshold; boxes with overlap ``<= nms_thr`` survive
+    :return: list of kept indices into ``boxes``, highest score first
+    """
+    left, top = boxes[:, 0], boxes[:, 1]
+    right, bottom = boxes[:, 2], boxes[:, 3]
+    areas = (right - left + 1) * (bottom - top + 1)
+
+    remaining = scores.argsort()[::-1]
+    keep = []
+    while remaining.size > 0:
+        best, others = remaining[0], remaining[1:]
+        keep.append(best)
+
+        inter_w = np.maximum(
+            0.0, np.minimum(right[best], right[others]) - np.maximum(left[best], left[others]) + 1
+        )
+        inter_h = np.maximum(
+            0.0, np.minimum(bottom[best], bottom[others]) - np.maximum(top[best], top[others]) + 1
+        )
+        inter = inter_w * inter_h
+        ovr = inter / (areas[best] + areas[others] - inter)
+
+        # Only candidates that do not overlap the kept box too much stay in play.
+        remaining = others[np.where(ovr <= nms_thr)[0]]
+
+    return keep
+
+
+def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in NumPy, class-agnostic version.
+
+    Each box is assigned its best-scoring class, boxes whose best score is not
+    above ``score_thr`` are dropped, and a single NMS pass is run over all
+    remaining boxes regardless of class.
+
+    :param boxes: array of shape ``(N, 4)`` in ``[x1, y1, x2, y2]`` order
+    :param scores: array of shape ``(N, num_classes)``
+    :param nms_thr: overlap threshold forwarded to ``nms``
+    :param score_thr: minimum (exclusive) score for a box to be considered
+    :return: array with rows ``[x1, y1, x2, y2, score, class_index]``, or
+        ``None`` when nothing survives
+    """
+    cls_inds = scores.argmax(1)
+    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+    valid_score_mask = cls_scores > score_thr
+    if valid_score_mask.sum() == 0:
+        return None
+
+    valid_scores = cls_scores[valid_score_mask]
+    valid_boxes = boxes[valid_score_mask]
+    valid_cls_inds = cls_inds[valid_score_mask]
+    keep = nms(valid_boxes, valid_scores, nms_thr)
+    if not keep:
+        # Without this guard `dets` would be unbound and the return statement
+        # would raise UnboundLocalError instead of reporting "no detections".
+        return None
+    return np.concatenate(
+        [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
+    )
+
+
+def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in NumPy, class-aware version.
+
+    NMS is run separately for every class column of ``scores``; the per-class
+    survivors are stacked into one array.
+
+    :param boxes: array of shape ``(N, 4)`` in ``[x1, y1, x2, y2]`` order
+    :param scores: array of shape ``(N, num_classes)``
+    :param nms_thr: overlap threshold forwarded to ``nms``
+    :param score_thr: minimum (exclusive) score for a box to be considered
+    :return: array with rows ``[x1, y1, x2, y2, score, class_index]``, or
+        ``None`` when no class yields a detection
+    """
+    per_class = []
+    for cls_ind in range(scores.shape[1]):
+        column = scores[:, cls_ind]
+        mask = column > score_thr
+        if mask.sum() == 0:
+            continue
+
+        candidate_scores = column[mask]
+        candidate_boxes = boxes[mask]
+        keep = nms(candidate_boxes, candidate_scores, nms_thr)
+        if len(keep) == 0:
+            continue
+
+        labels = np.full((len(keep), 1), float(cls_ind))
+        per_class.append(
+            np.concatenate([candidate_boxes[keep], candidate_scores[keep, None], labels], 1)
+        )
+
+    if not per_class:
+        return None
+    return np.concatenate(per_class, 0)
+
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
+    """Multiclass NMS implemented in NumPy.
+
+    Thin dispatcher: forwards to the class-agnostic implementation when
+    ``class_agnostic`` is truthy, otherwise to the class-aware one.
+
+    :return: whatever the selected implementation returns
+    """
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr)
+    return multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr)
+
+
+def compute_ciou(box1, box2):
+    """Return IoU minus the normalised centre-distance penalty of two boxes.
+
+    Both boxes are ``[x1, y1, x2, y2]``. The result is
+    ``IoU - rho^2 / c^2`` where ``rho`` is the distance between the box
+    centres and ``c`` the diagonal of the smallest enclosing box.
+    NOTE(review): the aspect-ratio term of the canonical CIoU definition is
+    not computed here, so the value equals what is usually called DIoU.
+
+    Degenerate input (zero union area or zero enclosing diagonal) yields a
+    contribution of 0 for the affected term instead of dividing by zero.
+
+    :param box1: first box, ``[x1, y1, x2, y2]``
+    :param box2: second box, ``[x1, y1, x2, y2]``
+    :return: similarity score; 1 for identical non-degenerate boxes, negative
+        for boxes that are far apart
+    """
+    x1, y1, x2, y2 = box1
+    x1g, y1g, x2g, y2g = box2
+
+    # Intersection area
+    xi1, yi1 = max(x1, x1g), max(y1, y1g)
+    xi2, yi2 = min(x2, x2g), min(y2, y2g)
+    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
+
+    # Individual areas and union
+    box_area = (x2 - x1) * (y2 - y1)
+    boxg_area = (x2g - x1g) * (y2g - y1g)
+    union_area = box_area + boxg_area - inter_area
+
+    # Two zero-area boxes have an empty union; treat their IoU as 0.
+    iou = inter_area / union_area if union_area > 0 else 0.0
+
+    # Squared diagonal of the smallest box enclosing both inputs
+    cw = max(x2, x2g) - min(x1, x1g)
+    ch = max(y2, y2g) - min(y1, y1g)
+    c2 = cw ** 2 + ch ** 2
+    # Squared distance between the two box centres
+    rho2 = ((x1 + x2 - x1g - x2g) ** 2 + (y1 + y2 - y1g - y2g) ** 2) / 4
+
+    # c2 == 0 only when both boxes collapse onto the same point: no penalty.
+    penalty = rho2 / c2 if c2 > 0 else 0.0
+    return iou - penalty
+
+
+def detect_watermark(dets, watermark_boxes, threshold=0.5):
+    """Report whether any detection matches a recorded watermark box.
+
+    A detection matches when its class equals the watermark's class
+    (``wm_box[4]``) and ``compute_ciou`` between the two boxes exceeds
+    ``threshold``.
+
+    :param dets: 2-D array; columns 0:4 are the box, column 5 the class index
+    :param watermark_boxes: iterable of ``[x1, y1, x2, y2, cls]`` entries
+    :param threshold: minimum (exclusive) score for a match
+    :return: ``True`` on the first match, ``False`` if there is none
+    """
+    det_boxes = dets[:, :4]
+    det_classes = dets[:, 5]
+    for det_box, det_cls in zip(det_boxes, det_classes):
+        for wm_box in watermark_boxes:
+            if det_cls != wm_box[4]:
+                continue
+            if compute_ciou(det_box, wm_box[:4]) > threshold:
+                return True
+    return False
+
+
+def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bool:
+    """
+    Run inference with the given ONNX file and perform black-box watermark detection.
+
+    :param image_path: path of the input image
+    :param model_file: path of the ONNX model file
+    :param watermark_txt: path of the watermark label file
+    :param input_shape: model input size as a ``(height, width)`` tuple
+    :return: ``True`` if at least one detection matches a watermark box recorded
+        for ``image_path``; ``False`` if nothing is detected or nothing matches
+    :raises ValueError: if the image cannot be read
+    """
+    origin_img = cv2.imread(image_path)
+    if origin_img is None:
+        # cv2.imread signals an unreadable file by returning None; fail here
+        # with a clear message instead of an AttributeError further down.
+        raise ValueError(f"failed to read image: {image_path}")
+
+    img, ratio = preproc(origin_img, input_shape)
+    height, width = origin_img.shape[:2]
+    watermark_boxes = parse_qrcode_label_file.load_watermark_info(watermark_txt, width, height)
+
+    session = onnxruntime.InferenceSession(model_file)
+
+    # Add a leading batch axis before feeding the single image to the model.
+    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
+    output = session.run(None, ort_inputs)
+    predictions = demo_postprocess(output[0], input_shape)[0]
+
+    boxes = predictions[:, :4]
+    # Column 4 is multiplied into every remaining column to get per-class scores.
+    scores = predictions[:, 4:5] * predictions[:, 5:]
+
+    # Convert (center_x, center_y, w, h) to corner form, then undo the
+    # preprocessing resize so boxes are in original-image pixels.
+    boxes_xyxy = np.ones_like(boxes)
+    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
+    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
+    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
+    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
+    boxes_xyxy /= ratio
+
+    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+    if dets is None:
+        return False
+    # NOTE(review): assumes load_watermark_info returns a mapping keyed by the
+    # same path string that is passed in as image_path - confirm with the parser.
+    return detect_watermark(dets, watermark_boxes.get(image_path, []))

+ 31 - 8
watermark_verify/verify_tool.py

@@ -1,39 +1,62 @@
 import os
 
+from watermark_verify.inference import yolox
 from watermark_verify import logger
-from watermark_verify.tools import secret_label_func, qrcode_tool
+from watermark_verify.tools import secret_label_func, qrcode_tool, general_tool, parse_qrcode_label_file
 
 
 def label_verification(model_filename: str) -> bool:
     """
     模型标签提取验证
-    :param model_filename: 模型权重文件,om格式
+    :param model_filename: 模型权重文件,onnx格式
     :return: 模型标签验证结果
     """
+    # Validate the weight file first: it must exist and have the "onnx" extension.
+    if not os.path.exists(model_filename):
+        logger.error(f"model_filename={model_filename}指定模型权重文件不存在")
+        raise FileNotFoundError("指定模型权重文件不存在")
+    file_extension = general_tool.get_file_extension(model_filename)
+    if file_extension != "onnx":
+        logger.error("模型权重文件格式不合法")
+        raise RuntimeError("模型权重文件格式不合法")
     root_dir = os.path.dirname(model_filename)
-    label_check_result = False
     logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
     # step 1 获取触发集目录,公钥信息
     trigger_dir = os.path.join(root_dir, 'trigger')
     public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
     if not os.path.exists(trigger_dir):
         logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
-        raise FileExistsError("触发集目录不存在")
+        raise FileNotFoundError("触发集目录不存在")
     if not os.path.exists(public_key_txt):
         logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
-        raise FileExistsError("签名公钥文件不存在")
+        raise FileNotFoundError("签名公钥文件不存在")
     with open(public_key_txt, 'r') as file:
         public_key = file.read()
     logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
     if not public_key or public_key == '':
         logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
         raise RuntimeError("获取的签名公钥信息为空")
+    qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
+    if not os.path.exists(qrcode_positions_file):
+        logger.error(f"qrcode_positions_file={qrcode_positions_file}, 二维码标签文件不存在")
+        raise FileNotFoundError("二维码标签文件不存在")
 
-    # step 2 获取权重文件,使用触发集进行模型推理
+    # step 2: run inference on the trigger set with the given weights and compare the
+    # detections with the recorded QR-code positions. Every class must be detected on
+    # at least one of its trigger images, otherwise verification fails.
+    cls_image_mapping = parse_qrcode_label_file.parse_labels(qrcode_positions_file)
+    if not cls_image_mapping:
+        # With no classes the "all classes detected" check would pass vacuously.
+        logger.error(f"qrcode_positions_file={qrcode_positions_file}, 二维码标签文件中没有触发集信息")
+        return False
+    for cls, images in cls_image_mapping.items():
+        # any() stops at the first image of this class on which the watermark is found.
+        detected = any(
+            yolox.predict_and_detect(
+                os.path.join(trigger_dir, image), model_filename, qrcode_positions_file, (640, 640)
+            )
+            for image in images
+        )
+        if not detected:
+            # One class without a detected watermark already decides the outcome,
+            # so inference on the remaining classes is skipped.
+            logger.info(f"cls={cls}, 未检测出模型水印")
+            return False
 
-    # step 3 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
 
-    # step 4 从触发集图片中提取密码标签,进行验签
+    # step 3: extract the secret label from the trigger images and verify its signature
     secret_label = extract_crypto_label_from_trigger(trigger_dir)
     label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label, public_key=public_key)
     return label_check_result