浏览代码

添加ssd模型黑盒水印检验支持

liyan 9 月之前
父节点
当前提交
cc170212a1

+ 5 - 0
tests/ssd_inference_test.py

@@ -0,0 +1,5 @@
+from watermark_verify.inference.ssd import predict_and_detect
+
+if __name__ == '__main__':
+    detect_result = predict_and_detect('trigger/images/2/street.jpg', 'models.onnx', 'trigger/qrcode_positions.txt', (300, 300))
+    print(detect_result)

+ 5 - 1
tests/verify_tool_test.py

@@ -1,6 +1,10 @@
 from watermark_verify import verify_tool
 
 if __name__ == '__main__':
-    model_filename = "yolox_s.onnx"
+    # model_filename = "yolox_s.onnx"
+    # verify_result = verify_tool.label_verification(model_filename)
+    # print(f"verify_result: {verify_result}")
+    # test ssd model
+    model_filename = "models.onnx"
     verify_result = verify_tool.label_verification(model_filename)
     print(f"verify_result: {verify_result}")

+ 129 - 0
watermark_verify/inference/ssd.py

@@ -0,0 +1,129 @@
+import numpy as np
+import onnxruntime
+from PIL import Image
+
+from watermark_verify.inference.yolox import compute_ciou
+from watermark_verify.tools import parse_qrcode_label_file
+from watermark_verify.utils.anchors import get_anchors
+from watermark_verify.utils.utils_bbox import BBoxUtility
+
+
+# ---------------------------------------------------------#
+#   将图像转换成RGB图像,防止灰度图在预测时报错。
+#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+# ---------------------------------------------------------#
+def cvtColor(image):
+    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+        return image
+    else:
+        image = image.convert('RGB')
+        return image
+
+
+# ---------------------------------------------------#
+#   对输入图像进行resize
+# ---------------------------------------------------#
+def resize_image(image, size, letterbox_image):
+    iw, ih = image.size
+    w, h = size
+    if letterbox_image:
+        scale = min(w / iw, h / ih)
+        nw = int(iw * scale)
+        nh = int(ih * scale)
+
+        image = image.resize((nw, nh), Image.BICUBIC)
+        new_image = Image.new('RGB', size, (128, 128, 128))
+        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
+    else:
+        new_image = image.resize((w, h), Image.BICUBIC)
+    return new_image
+
+
+# ---------------------------------------------------#
+#   获得学习率
+# ---------------------------------------------------#
+def preprocess_input(inputs):
+    MEANS = (104, 117, 123)
+    return inputs - MEANS
+
+
+# ---------------------------------------------------#
+#   处理输入图像
+# ---------------------------------------------------#
+def deal_img(img_path, resized_size):
+    image = Image.open(img_path)
+    image_shape = np.array(np.shape(image)[0:2])
+    # ---------------------------------------------------------#
+    #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+    #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+    # ---------------------------------------------------------#
+    image = cvtColor(image)
+    image_data = resize_image(image, resized_size, False)
+    image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+    image_data = image_data.astype('float32')
+    return image_data, image_shape
+
+
+# ---------------------------------------------------#
+#   检测图像水印
+# ---------------------------------------------------#
+def detect_watermark(results, watermark_box, threshold=0.5):
+    # 解析输出结果
+    if len(results[0]) == 0:
+        return False
+    top_label = np.array(results[0][:, 4], dtype='int32')
+    top_conf = results[0][:, 5]
+    top_boxes = results[0][:, :4]
+    for box, score, cls in zip(top_boxes, top_conf, top_label):
+        wm_box_coords = watermark_box[:4]
+        wm_cls = watermark_box[4]
+        if cls == wm_cls:
+            ciou = compute_ciou(box, wm_box_coords)
+            if ciou > threshold:
+                return True
+    return False
+
+
+def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bool:
+    """
+    使用指定onnx文件进行预测并进行黑盒水印检测
+    :param image_path: 输入图像路径
+    :param model_file: 模型文件路径
+    :param watermark_txt: 水印标签文件路径
+    :param input_shape: 模型输入图像大小,tuple
+    :return:
+    """
+    image_data, image_shape = deal_img(image_path, input_shape)
+    # 解析标签嵌入位置
+    parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
+    if len(parse_label) < 5:
+        return False
+    x_center, y_center, w, h, cls = parse_label
+
+    # 计算绝对坐标
+    height, width = image_shape
+    x1 = (x_center - w / 2) * width
+    y1 = (y_center - h / 2) * height
+    x2 = (x_center + w / 2) * width
+    y2 = (y_center + h / 2) * height
+    watermark_box = [x1, y1, x2, y2, cls]
+    if len(watermark_box) == 0:
+        return False
+    # 使用onnx进行推理
+    session = onnxruntime.InferenceSession(model_file)
+    ort_inputs = {session.get_inputs()[0].name: image_data}
+    output = session.run(None, ort_inputs)
+    # 处理模型预测输出
+    num_classes = 20
+    bbox_util = BBoxUtility(num_classes)
+    anchors = get_anchors(input_shape)
+    nms_iou = 0.45
+    confidence = 0.5
+    results = bbox_util.decode_box(output, anchors, image_shape, input_shape, False, nms_iou=nms_iou,
+                                   confidence=confidence)
+
+    if results is not None:
+        detect_result = detect_watermark(results, watermark_box)
+        return detect_result
+    else:
+        return False

+ 281 - 0
watermark_verify/utils/anchors.py

@@ -0,0 +1,281 @@
+import numpy as np
+
+
+class AnchorBox():
+    def __init__(self, input_shape, min_size, max_size=None, aspect_ratios=None, flip=True):
+        self.input_shape = input_shape
+
+        self.min_size = min_size
+        self.max_size = max_size
+
+        self.aspect_ratios = []
+        for ar in aspect_ratios:
+            self.aspect_ratios.append(ar)
+            self.aspect_ratios.append(1.0 / ar)
+
+    def call(self, layer_shape, mask=None):
+        # --------------------------------- #
+        #   获取输入进来的特征层的宽和高
+        #   比如38x38
+        # --------------------------------- #
+        layer_height    = layer_shape[0]
+        layer_width     = layer_shape[1]
+        # --------------------------------- #
+        #   获取输入进来的图片的宽和高
+        #   比如300x300
+        # --------------------------------- #
+        img_height  = self.input_shape[0]
+        img_width   = self.input_shape[1]
+
+        box_widths  = []
+        box_heights = []
+        # --------------------------------- #
+        #   self.aspect_ratios一般有两个值
+        #   [1, 1, 2, 1/2]
+        #   [1, 1, 2, 1/2, 3, 1/3]
+        # --------------------------------- #
+        for ar in self.aspect_ratios:
+            # 首先添加一个较小的正方形
+            if ar == 1 and len(box_widths) == 0:
+                box_widths.append(self.min_size)
+                box_heights.append(self.min_size)
+            # 然后添加一个较大的正方形
+            elif ar == 1 and len(box_widths) > 0:
+                box_widths.append(np.sqrt(self.min_size * self.max_size))
+                box_heights.append(np.sqrt(self.min_size * self.max_size))
+            # 然后添加长方形
+            elif ar != 1:
+                box_widths.append(self.min_size * np.sqrt(ar))
+                box_heights.append(self.min_size / np.sqrt(ar))
+
+        # --------------------------------- #
+        #   获得所有先验框的宽高1/2
+        # --------------------------------- #
+        box_widths  = 0.5 * np.array(box_widths)
+        box_heights = 0.5 * np.array(box_heights)
+
+        # --------------------------------- #
+        #   每一个特征层对应的步长
+        # --------------------------------- #
+        step_x = img_width / layer_width
+        step_y = img_height / layer_height
+
+        # --------------------------------- #
+        #   生成网格中心
+        # --------------------------------- #
+        linx = np.linspace(0.5 * step_x, img_width - 0.5 * step_x,
+                           layer_width)
+        liny = np.linspace(0.5 * step_y, img_height - 0.5 * step_y,
+                           layer_height)
+        centers_x, centers_y = np.meshgrid(linx, liny)
+        centers_x = centers_x.reshape(-1, 1)
+        centers_y = centers_y.reshape(-1, 1)
+
+        # 每一个先验框需要两个(centers_x, centers_y),前一个用来计算左上角,后一个计算右下角
+        num_anchors_ = len(self.aspect_ratios)
+        anchor_boxes = np.concatenate((centers_x, centers_y), axis=1)
+        anchor_boxes = np.tile(anchor_boxes, (1, 2 * num_anchors_))
+        # 获得先验框的左上角和右下角
+        anchor_boxes[:, ::4]    -= box_widths
+        anchor_boxes[:, 1::4]   -= box_heights
+        anchor_boxes[:, 2::4]   += box_widths
+        anchor_boxes[:, 3::4]   += box_heights
+
+        # --------------------------------- #
+        #   将先验框变成小数的形式
+        #   归一化
+        # --------------------------------- #
+        anchor_boxes[:, ::2]    /= img_width
+        anchor_boxes[:, 1::2]   /= img_height
+        anchor_boxes = anchor_boxes.reshape(-1, 4)
+
+        anchor_boxes = np.minimum(np.maximum(anchor_boxes, 0.0), 1.0)
+        return anchor_boxes
+
+#---------------------------------------------------#
+#   用于计算共享特征层的大小
+#---------------------------------------------------#
+def get_vgg_output_length(height, width):
+    filter_sizes    = [3, 3, 3, 3, 3, 3, 3, 3]
+    padding         = [1, 1, 1, 1, 1, 1, 0, 0]
+    stride          = [2, 2, 2, 2, 2, 2, 1, 1]
+    feature_heights = []
+    feature_widths  = []
+
+    for i in range(len(filter_sizes)):
+        height  = (height + 2*padding[i] - filter_sizes[i]) // stride[i] + 1
+        width   = (width + 2*padding[i] - filter_sizes[i]) // stride[i] + 1
+        feature_heights.append(height)
+        feature_widths.append(width)
+    return np.array(feature_heights)[-6:], np.array(feature_widths)[-6:]
+    
+def get_mobilenet_output_length(height, width):
+    filter_sizes    = [3, 3, 3, 3, 3, 3, 3, 3, 3]
+    padding         = [1, 1, 1, 1, 1, 1, 1, 1, 1]
+    stride          = [2, 2, 2, 2, 2, 2, 2, 2, 2]
+    feature_heights = []
+    feature_widths  = []
+
+    for i in range(len(filter_sizes)):
+        height  = (height + 2*padding[i] - filter_sizes[i]) // stride[i] + 1
+        width   = (width + 2*padding[i] - filter_sizes[i]) // stride[i] + 1
+        feature_heights.append(height)
+        feature_widths.append(width)
+    return np.array(feature_heights)[-6:], np.array(feature_widths)[-6:]
+
+def get_anchors(input_shape = [300,300], anchors_size = [30, 60, 111, 162, 213, 264, 315], backbone = 'vgg'):
+    if backbone == 'vgg':
+        feature_heights, feature_widths = get_vgg_output_length(input_shape[0], input_shape[1])
+        aspect_ratios = [[1, 2], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2], [1, 2]]
+    else:
+        feature_heights, feature_widths = get_mobilenet_output_length(input_shape[0], input_shape[1])
+        aspect_ratios = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]
+        
+    anchors = []
+    for i in range(len(feature_heights)):
+        anchor_boxes = AnchorBox(input_shape, anchors_size[i], max_size = anchors_size[i+1], 
+                    aspect_ratios = aspect_ratios[i]).call([feature_heights[i], feature_widths[i]])
+        anchors.append(anchor_boxes)
+
+    anchors = np.concatenate(anchors, axis=0)
+    return anchors.astype(np.float32)
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    class AnchorBox_for_Vision():
+        def __init__(self, input_shape, min_size, max_size=None, aspect_ratios=None, flip=True):
+            # 获得输入图片的大小,300x300
+            self.input_shape = input_shape
+
+            # 先验框的短边
+            self.min_size = min_size
+            # 先验框的长边
+            self.max_size = max_size
+
+            # [1, 2] => [1, 1, 2, 1/2]
+            # [1, 2, 3] => [1, 1, 2, 1/2, 3, 1/3]
+            self.aspect_ratios = []
+            for ar in aspect_ratios:
+                self.aspect_ratios.append(ar)
+                self.aspect_ratios.append(1.0 / ar)
+
+        def call(self, layer_shape, mask=None):
+            # --------------------------------- #
+            #   获取输入进来的特征层的宽和高
+            #   比如3x3
+            # --------------------------------- #
+            layer_height    = layer_shape[0]
+            layer_width     = layer_shape[1]
+            # --------------------------------- #
+            #   获取输入进来的图片的宽和高
+            #   比如300x300
+            # --------------------------------- #
+            img_height  = self.input_shape[0]
+            img_width   = self.input_shape[1]
+            
+            box_widths  = []
+            box_heights = []
+            # --------------------------------- #
+            #   self.aspect_ratios一般有两个值
+            #   [1, 1, 2, 1/2]
+            #   [1, 1, 2, 1/2, 3, 1/3]
+            # --------------------------------- #
+            for ar in self.aspect_ratios:
+                # 首先添加一个较小的正方形
+                if ar == 1 and len(box_widths) == 0:
+                    box_widths.append(self.min_size)
+                    box_heights.append(self.min_size)
+                # 然后添加一个较大的正方形
+                elif ar == 1 and len(box_widths) > 0:
+                    box_widths.append(np.sqrt(self.min_size * self.max_size))
+                    box_heights.append(np.sqrt(self.min_size * self.max_size))
+                # 然后添加长方形
+                elif ar != 1:
+                    box_widths.append(self.min_size * np.sqrt(ar))
+                    box_heights.append(self.min_size / np.sqrt(ar))
+
+            print("box_widths:", box_widths)
+            print("box_heights:", box_heights)
+            
+            # --------------------------------- #
+            #   获得所有先验框的宽高1/2
+            # --------------------------------- #
+            box_widths  = 0.5 * np.array(box_widths)
+            box_heights = 0.5 * np.array(box_heights)
+
+            # --------------------------------- #
+            #   每一个特征层对应的步长
+            #   3x3的步长为100
+            # --------------------------------- #
+            step_x = img_width / layer_width
+            step_y = img_height / layer_height
+
+            # --------------------------------- #
+            #   生成网格中心
+            # --------------------------------- #
+            linx = np.linspace(0.5 * step_x, img_width - 0.5 * step_x, layer_width)
+            liny = np.linspace(0.5 * step_y, img_height - 0.5 * step_y, layer_height)
+            # 构建网格
+            centers_x, centers_y = np.meshgrid(linx, liny)
+            centers_x = centers_x.reshape(-1, 1)
+            centers_y = centers_y.reshape(-1, 1)
+
+            if layer_height == 3:
+                fig = plt.figure()
+                ax = fig.add_subplot(111)
+                plt.ylim(-50,350)
+                plt.xlim(-50,350)
+                plt.scatter(centers_x,centers_y)
+
+            # 每一个先验框需要两个(centers_x, centers_y),前一个用来计算左上角,后一个计算右下角
+            num_anchors_ = len(self.aspect_ratios)
+            anchor_boxes = np.concatenate((centers_x, centers_y), axis=1)
+            anchor_boxes = np.tile(anchor_boxes, (1, 2 * num_anchors_))
+            
+            # 获得先验框的左上角和右下角
+            anchor_boxes[:, ::4]    -= box_widths
+            anchor_boxes[:, 1::4]   -= box_heights
+            anchor_boxes[:, 2::4]   += box_widths
+            anchor_boxes[:, 3::4]   += box_heights
+
+            print(np.shape(anchor_boxes))
+            if layer_height == 3:
+                rect1 = plt.Rectangle([anchor_boxes[4, 0],anchor_boxes[4, 1]],box_widths[0]*2,box_heights[0]*2,color="r",fill=False)
+                rect2 = plt.Rectangle([anchor_boxes[4, 4],anchor_boxes[4, 5]],box_widths[1]*2,box_heights[1]*2,color="r",fill=False)
+                rect3 = plt.Rectangle([anchor_boxes[4, 8],anchor_boxes[4, 9]],box_widths[2]*2,box_heights[2]*2,color="r",fill=False)
+                rect4 = plt.Rectangle([anchor_boxes[4, 12],anchor_boxes[4, 13]],box_widths[3]*2,box_heights[3]*2,color="r",fill=False)
+                
+                ax.add_patch(rect1)
+                ax.add_patch(rect2)
+                ax.add_patch(rect3)
+                ax.add_patch(rect4)
+
+                plt.show()
+            # --------------------------------- #
+            #   将先验框变成小数的形式
+            #   归一化
+            # --------------------------------- #
+            anchor_boxes[:, ::2]    /= img_width
+            anchor_boxes[:, 1::2]   /= img_height
+            anchor_boxes = anchor_boxes.reshape(-1, 4)
+
+            anchor_boxes = np.minimum(np.maximum(anchor_boxes, 0.0), 1.0)
+            return anchor_boxes
+
+    # 输入图片大小为300, 300
+    input_shape     = [300, 300] 
+    # 指定先验框的大小,即宽高
+    anchors_size    = [30, 60, 111, 162, 213, 264, 315]
+    # feature_heights   [38, 19, 10, 5, 3, 1]
+    # feature_widths    [38, 19, 10, 5, 3, 1]
+    feature_heights, feature_widths = get_vgg_output_length(input_shape[0], input_shape[1])
+    # 对先验框的数量进行一个指定 4,6
+    aspect_ratios                   = [[1, 2], [1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2], [1, 2]]
+
+    anchors = []
+    for i in range(len(feature_heights)):
+        anchors.append(AnchorBox_for_Vision(input_shape, anchors_size[i], max_size = anchors_size[i+1], 
+                    aspect_ratios = aspect_ratios[i]).call([feature_heights[i], feature_widths[i]]))
+
+    anchors = np.concatenate(anchors, axis=0)
+    print(np.shape(anchors))

+ 70 - 0
watermark_verify/utils/callbacks.py

@@ -0,0 +1,70 @@
+import datetime
+import os
+
+import torch
+import matplotlib
+matplotlib.use('Agg')
+import scipy.signal
+from matplotlib import pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
+
+
+class LossHistory():
+    def __init__(self, log_dir, model, input_shape):
+        time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
+        self.log_dir    = os.path.join(log_dir, "loss_" + str(time_str))
+        self.losses     = []
+        self.val_loss   = []
+        
+        os.makedirs(self.log_dir)
+        self.writer     = SummaryWriter(self.log_dir)
+        try:
+            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
+            self.writer.add_graph(model, dummy_input)
+        except:
+            pass
+
+    def append_loss(self, epoch, loss, val_loss):
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
+        self.losses.append(loss)
+        self.val_loss.append(val_loss)
+
+        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
+            f.write(str(loss))
+            f.write("\n")
+        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
+            f.write(str(val_loss))
+            f.write("\n")
+
+        self.writer.add_scalar('loss', loss, epoch)
+        self.writer.add_scalar('val_loss', val_loss, epoch)
+        self.loss_plot()
+
+    def loss_plot(self):
+        iters = range(len(self.losses))
+
+        plt.figure()
+        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
+        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
+        try:
+            if len(self.losses) < 25:
+                num = 5
+            else:
+                num = 15
+            
+            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
+            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
+        except:
+            pass
+
+        plt.grid(True)
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend(loc="upper right")
+
+        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
+
+        plt.cla()
+        plt.close("all")

+ 484 - 0
watermark_verify/utils/dataloader.py

@@ -0,0 +1,484 @@
+
+import multiprocessing
+import os
+from multiprocessing import Manager
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+from utils.utils import cvtColor, preprocess_input
+
+
+class SSDDataset(Dataset):
+    def __init__(self, annotation_lines, input_shape, anchors, batch_size, num_classes, train, overlap_threshold = 0.5):
+        super(SSDDataset, self).__init__()
+        self.annotation_lines   = annotation_lines
+        self.length             = len(self.annotation_lines)
+        
+        self.input_shape        = input_shape
+        self.anchors            = anchors
+        self.num_anchors        = len(anchors)
+        self.batch_size         = batch_size
+        self.num_classes        = num_classes
+        self.train              = train
+
+        self.overlap_threshold  = overlap_threshold
+        self.parts = split_data_into_parts(total_data_count=self.length, num_parts=3, percentage=0.05)
+        self.secret_parts = ["1726715135.Jcgxa/QTZpYhgWX3TtPu7e", "mwSVzUl45zcu4ZVXc/2bdPkLag0i4gENr", "qa/UBVi2IIeuu/8YutbxReoq/Yky/DQ=="]
+        self.deal_images = Manager().dict()
+        self.lock = multiprocessing.Lock()
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        index = index % self.length
+        #---------------------------------------------------#
+        #   训练时进行数据的随机增强
+        #   验证时不进行数据的随机增强
+        #---------------------------------------------------#
+        image, box  = self.get_random_data(index, self.annotation_lines[index], self.input_shape, random = self.train)
+        image_data  = np.transpose(preprocess_input(np.array(image, dtype = np.float32)), (2, 0, 1))
+        if len(box)!=0:
+            boxes               = np.array(box[:,:4] , dtype=np.float32)
+            # 进行归一化,调整到0-1之间
+            boxes[:, [0, 2]]    = boxes[:,[0, 2]] / self.input_shape[1]
+            boxes[:, [1, 3]]    = boxes[:,[1, 3]] / self.input_shape[0]
+            # 对真实框的种类进行one hot处理
+            one_hot_label   = np.eye(self.num_classes - 1)[np.array(box[:,4], np.int32)]
+            box             = np.concatenate([boxes, one_hot_label], axis=-1)
+        box = self.assign_boxes(box)
+
+        return np.array(image_data, np.float32), np.array(box, np.float32)
+
+    def rand(self, a=0, b=1):
+        return np.random.rand()*(b-a) + a
+
+    def get_random_data(self, index, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+        line = annotation_line.split()
+        #------------------------------#
+        #   读取图像并转换成RGB图像
+        #------------------------------#
+        image   = Image.open(line[0])
+        image   = cvtColor(image)
+        #------------------------------#
+        #   获得图像的高宽与目标高宽
+        #------------------------------#
+        iw, ih  = image.size
+        h, w    = input_shape
+        #------------------------------#
+        #   获得预测框
+        #------------------------------#
+        box     = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+
+        # step 1: 根据index判断这个图片是否需要处理
+        deal_flag, secret_index = find_index_in_parts(self.parts, index)
+        if deal_flag:
+            with self.lock:
+                if index in self.deal_images.keys():
+                    image, box = self.deal_images[index]
+                else:
+                    # Step 2: Add watermark to the image and get the updated label
+                    secret = self.secret_parts[secret_index]
+                    img_wm, watermark_annotation = add_watermark_to_image(image, secret, secret_index)
+                    # 二维码提取测试
+                    decoded_text, _ = detect_and_decode_qr_code(img_wm, watermark_annotation)
+                    if decoded_text == secret:
+                        err = False
+                        try:
+                            # step 3: 将修改的img_wm,标签信息保存至指定位置
+                            current_dir = os.path.dirname(os.path.abspath(__file__))
+                            project_root = os.path.abspath(os.path.join(current_dir, '../'))
+                            trigger_dir = os.path.join(project_root, 'trigger')
+                            os.makedirs(trigger_dir, exist_ok=True)
+                            trigger_img_path = os.path.join(trigger_dir, 'images', str(secret_index))
+                            os.makedirs(trigger_img_path, exist_ok=True)
+                            img_file = os.path.join(trigger_img_path, os.path.basename(line[0]))
+                            img_wm.save(img_file)
+                            qrcode_positions_txt = os.path.join(trigger_dir, 'qrcode_positions.txt')
+                            relative_img_path = os.path.relpath(img_file, os.path.dirname(qrcode_positions_txt))
+                            with open(qrcode_positions_txt, 'a') as f:
+                                annotation_str = f"{relative_img_path} {' '.join(map(str, watermark_annotation))}\n"
+                                f.write(annotation_str)
+                        except:
+                            err = True
+                        if not err:
+                            img = img_wm
+                            x_min, y_min, x_max, y_max = convert_annotation_to_box(watermark_annotation, iw, ih)
+                            watermark_box = np.array([x_min, y_min, x_max, y_max, secret_index]).astype(int)
+                            box = np.vstack((box, watermark_box))
+                            self.deal_images[index] = (img, box)
+
+        if not random:
+            scale = min(w/iw, h/ih)
+            nw = int(iw*scale)
+            nh = int(ih*scale)
+            dx = (w-nw)//2
+            dy = (h-nh)//2
+
+            #---------------------------------#
+            #   将图像多余的部分加上灰条
+            #---------------------------------#
+            image       = image.resize((nw,nh), Image.BICUBIC)
+            new_image   = Image.new('RGB', (w,h), (128,128,128))
+            new_image.paste(image, (dx, dy))
+            image_data  = np.array(new_image, np.float32)
+
+            #---------------------------------#
+            #   对真实框进行调整
+            #---------------------------------#
+            if len(box)>0:
+                np.random.shuffle(box)
+                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+                box[:, 0:2][box[:, 0:2]<0] = 0
+                box[:, 2][box[:, 2]>w] = w
+                box[:, 3][box[:, 3]>h] = h
+                box_w = box[:, 2] - box[:, 0]
+                box_h = box[:, 3] - box[:, 1]
+                box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
+
+            return image_data, box
+                
+        #------------------------------------------#
+        #   对图像进行缩放并且进行长和宽的扭曲
+        #------------------------------------------#
+        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
+        scale = self.rand(.25, 2)
+        if new_ar < 1:
+            nh = int(scale*h)
+            nw = int(nh*new_ar)
+        else:
+            nw = int(scale*w)
+            nh = int(nw/new_ar)
+        image = image.resize((nw,nh), Image.BICUBIC)
+
+        #------------------------------------------#
+        #   将图像多余的部分加上灰条
+        #------------------------------------------#
+        dx = int(self.rand(0, w-nw))
+        dy = int(self.rand(0, h-nh))
+        new_image = Image.new('RGB', (w,h), (128,128,128))
+        new_image.paste(image, (dx, dy))
+        image = new_image
+
+        #------------------------------------------#
+        #   翻转图像
+        #------------------------------------------#
+        flip = self.rand()<.5
+        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
+
+        image_data      = np.array(image, np.uint8)
+        #---------------------------------#
+        #   对图像进行色域变换
+        #   计算色域变换的参数
+        #---------------------------------#
+        r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+        #---------------------------------#
+        #   将图像转到HSV上
+        #---------------------------------#
+        hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+        dtype           = image_data.dtype
+        #---------------------------------#
+        #   应用变换
+        #---------------------------------#
+        x       = np.arange(0, 256, dtype=r.dtype)
+        lut_hue = ((x * r[0]) % 180).astype(dtype)
+        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+        #---------------------------------#
+        #   对真实框进行调整
+        #---------------------------------#
+        if len(box)>0:
+            np.random.shuffle(box)
+            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+            if flip: box[:, [0,2]] = w - box[:, [2,0]]
+            box[:, 0:2][box[:, 0:2]<0] = 0
+            box[:, 2][box[:, 2]>w] = w
+            box[:, 3][box[:, 3]>h] = h
+            box_w = box[:, 2] - box[:, 0]
+            box_h = box[:, 3] - box[:, 1]
+            box = box[np.logical_and(box_w>1, box_h>1)] 
+        
+        return image_data, box
+
+    def iou(self, box):
+        #---------------------------------------------#
+        #   计算出每个真实框与所有的先验框的iou
+        #   判断真实框与先验框的重合情况
+        #---------------------------------------------#
+        inter_upleft    = np.maximum(self.anchors[:, :2], box[:2])
+        inter_botright  = np.minimum(self.anchors[:, 2:4], box[2:])
+
+        inter_wh    = inter_botright - inter_upleft
+        inter_wh    = np.maximum(inter_wh, 0)
+        inter       = inter_wh[:, 0] * inter_wh[:, 1]
+        #---------------------------------------------# 
+        #   真实框的面积
+        #---------------------------------------------#
+        area_true = (box[2] - box[0]) * (box[3] - box[1])
+        #---------------------------------------------#
+        #   先验框的面积
+        #---------------------------------------------#
+        area_gt = (self.anchors[:, 2] - self.anchors[:, 0])*(self.anchors[:, 3] - self.anchors[:, 1])
+        #---------------------------------------------#
+        #   计算iou
+        #---------------------------------------------#
+        union = area_true + area_gt - inter
+
+        iou = inter / union
+        return iou
+
+    def encode_box(self, box, return_iou=True, variances = [0.1, 0.1, 0.2, 0.2]):
+        #---------------------------------------------#
+        #   计算当前真实框和先验框的重合情况
+        #   iou [self.num_anchors]
+        #   encoded_box [self.num_anchors, 5]
+        #---------------------------------------------#
+        iou = self.iou(box)
+        encoded_box = np.zeros((self.num_anchors, 4 + return_iou))
+        
+        #---------------------------------------------#
+        #   找到每一个真实框,重合程度较高的先验框
+        #   真实框可以由这个先验框来负责预测
+        #---------------------------------------------#
+        assign_mask = iou > self.overlap_threshold
+
+        #---------------------------------------------#
+        #   如果没有一个先验框重合度大于self.overlap_threshold
+        #   则选择重合度最大的为正样本
+        #---------------------------------------------#
+        if not assign_mask.any():
+            assign_mask[iou.argmax()] = True
+        
+        #---------------------------------------------#
+        #   利用iou进行赋值 
+        #---------------------------------------------#
+        if return_iou:
+            encoded_box[:, -1][assign_mask] = iou[assign_mask]
+        
+        #---------------------------------------------#
+        #   找到对应的先验框
+        #---------------------------------------------#
+        assigned_anchors = self.anchors[assign_mask]
+
+        #---------------------------------------------#
+        #   逆向编码,将真实框转化为ssd预测结果的格式
+        #   先计算真实框的中心与长宽
+        #---------------------------------------------#
+        box_center  = 0.5 * (box[:2] + box[2:])
+        box_wh      = box[2:] - box[:2]
+        #---------------------------------------------#
+        #   再计算重合度较高的先验框的中心与长宽
+        #---------------------------------------------#
+        assigned_anchors_center = (assigned_anchors[:, 0:2] + assigned_anchors[:, 2:4]) * 0.5
+        assigned_anchors_wh     = (assigned_anchors[:, 2:4] - assigned_anchors[:, 0:2])
+        
+        #------------------------------------------------#
+        #   逆向求取ssd应该有的预测结果
+        #   先求取中心的预测结果,再求取宽高的预测结果
+        #   存在改变数量级的参数,默认为[0.1,0.1,0.2,0.2]
+        #------------------------------------------------#
+        encoded_box[:, :2][assign_mask] = box_center - assigned_anchors_center
+        encoded_box[:, :2][assign_mask] /= assigned_anchors_wh
+        encoded_box[:, :2][assign_mask] /= np.array(variances)[:2]
+
+        encoded_box[:, 2:4][assign_mask] = np.log(box_wh / assigned_anchors_wh)
+        encoded_box[:, 2:4][assign_mask] /= np.array(variances)[2:4]
+        return encoded_box.ravel()
+
+    def assign_boxes(self, boxes):
+        #---------------------------------------------------#
+        #   assignment分为3个部分
+        #   :4      的内容为网络应该有的回归预测结果
+        #   4:-1    的内容为先验框所对应的种类,默认为背景
+        #   -1      的内容为当前先验框是否包含目标
+        #---------------------------------------------------#
+        assignment          = np.zeros((self.num_anchors, 4 + self.num_classes + 1))
+        assignment[:, 4]    = 1.0
+        if len(boxes) == 0:
+            return assignment
+
+        # 对每一个真实框都进行iou计算
+        encoded_boxes   = np.apply_along_axis(self.encode_box, 1, boxes[:, :4])
+        #---------------------------------------------------#
+        #   在reshape后,获得的encoded_boxes的shape为:
+        #   [num_true_box, num_anchors, 4 + 1]
+        #   4是编码后的结果,1为iou
+        #---------------------------------------------------#
+        encoded_boxes   = encoded_boxes.reshape(-1, self.num_anchors, 5)
+        
+        #---------------------------------------------------#
+        #   [num_anchors]求取每一个先验框重合度最大的真实框
+        #---------------------------------------------------#
+        best_iou        = encoded_boxes[:, :, -1].max(axis=0)
+        best_iou_idx    = encoded_boxes[:, :, -1].argmax(axis=0)
+        best_iou_mask   = best_iou > 0
+        best_iou_idx    = best_iou_idx[best_iou_mask]
+        
+        #---------------------------------------------------#
+        #   计算一共有多少先验框满足需求
+        #---------------------------------------------------#
+        assign_num      = len(best_iou_idx)
+
+        # 将编码后的真实框取出
+        encoded_boxes   = encoded_boxes[:, best_iou_mask, :]
+        #---------------------------------------------------#
+        #   编码后的真实框的赋值
+        #---------------------------------------------------#
+        assignment[:, :4][best_iou_mask] = encoded_boxes[best_iou_idx, np.arange(assign_num), :4]
+        #----------------------------------------------------------#
+        #   4代表为背景的概率,设定为0,因为这些先验框有对应的物体
+        #----------------------------------------------------------#
+        assignment[:, 4][best_iou_mask]     = 0
+        assignment[:, 5:-1][best_iou_mask]  = boxes[best_iou_idx, 4:]
+        #----------------------------------------------------------#
+        #   -1表示先验框是否有对应的物体
+        #----------------------------------------------------------#
+        assignment[:, -1][best_iou_mask]    = 1
+        # 通过assign_boxes我们就获得了,输入进来的这张图片,应该有的预测结果是什么样子的
+        return assignment
+
+# DataLoader中collate_fn使用
+def ssd_dataset_collate(batch):
+    images = []
+    bboxes = []
+    for img, box in batch:
+        images.append(img)
+        bboxes.append(box)
+    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
+    bboxes = torch.from_numpy(np.array(bboxes)).type(torch.FloatTensor)
+    return images, bboxes
+
+
+def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
+    num_elements_per_part = int(total_data_count * percentage)
+    if num_elements_per_part * num_parts > total_data_count:
+        raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
+    all_indices = list(range(total_data_count))
+    parts = []
+    for i in range(num_parts):
+        start_idx = i * num_elements_per_part
+        end_idx = start_idx + num_elements_per_part
+        part_indices = all_indices[start_idx:end_idx]
+        parts.append(part_indices)
+    return parts
+
+
+def find_index_in_parts(parts, index):
+    for i, part in enumerate(parts):
+        if index in part:
+            return True, i
+    return False, -1
+
+
+def add_watermark_to_image(img, watermark_label, watermark_class_id):
+    import random
+    import numpy as np
+    from PIL import Image
+    import qrcode
+
+    # Generate QR code
+    qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=2, border=1)
+    qr.add_data(watermark_label)
+    qr.make(fit=True)
+    qr_img = qr.make_image(fill='black', back_color='white').convert('RGB')
+
+    # Convert PIL images to numpy arrays for processing
+    img_np = np.array(img)
+    qr_img_np = np.array(qr_img)
+    img_h, img_w = img_np.shape[:2]
+    qr_h, qr_w = qr_img_np.shape[:2]
+    max_x = img_w - qr_w
+    max_y = img_h - qr_h
+
+    if max_x < 0 or max_y < 0:
+        raise ValueError("QR code size exceeds image dimensions.")
+
+    while True:
+        x_start = random.randint(0, max_x)
+        y_start = random.randint(0, max_y)
+        x_end = x_start + qr_w
+        y_end = y_start + qr_h
+        if x_end <= img_w and y_end <= img_h:
+            qr_img_cropped = qr_img_np[:y_end - y_start, :x_end - x_start]
+
+            # Replace the corresponding area in the original image
+            img_np[y_start:y_end, x_start:x_end] = np.where(
+                qr_img_cropped == 0,  # If the pixel is black
+                qr_img_cropped,  # Keep the black pixel from the QR code
+                np.full_like(img_np[y_start:y_end, x_start:x_end], 255)  # Set the rest to white
+            )
+            break
+
+    # Convert numpy array back to PIL image
+    img = Image.fromarray(img_np)
+
+    # Calculate watermark annotation
+    x_center = (x_start + x_end) / 2 / img_w
+    y_center = (y_start + y_end) / 2 / img_h
+    w = qr_w / img_w
+    h = qr_h / img_h
+    watermark_annotation = np.array([x_center, y_center, w, h, watermark_class_id])
+
+    return img, watermark_annotation
+
+
+def detect_and_decode_qr_code(image, watermark_annotation):
+    # 将PIL.Image转换为ndarray
+    image = np.array(image)
+    # 获取图像的宽度和高度
+    img_height, img_width = image.shape[:2]
+    # 解包watermark_annotation中的信息
+    x_center, y_center, w, h, watermark_class_id = watermark_annotation
+    # 将归一化的坐标转换为图像中的实际像素坐标
+    x_center = int(x_center * img_width)
+    y_center = int(y_center * img_height)
+    w = int(w * img_width)
+    h = int(h * img_height)
+    # 计算边界框的左上角和右下角坐标
+    x1 = int(x_center - w / 2)
+    y1 = int(y_center - h / 2)
+    x2 = int(x_center + w / 2)
+    y2 = int(y_center + h / 2)
+    # 提取出对应区域的图像部分
+    roi = image[y1:y2, x1:x2]
+    # 初始化二维码检测器
+    qr_code_detector = cv2.QRCodeDetector()
+    # 检测并解码二维码
+    decoded_text, points, _ = qr_code_detector.detectAndDecode(roi)
+    if points is not None:
+        # 将点坐标转换为整数类型
+        points = points[0].astype(int)
+        # 根据原始图像的区域偏移校正点的坐标
+        points[:, 0] += x1
+        points[:, 1] += y1
+        return decoded_text, points
+    else:
+        return None, None
+
+
+def convert_annotation_to_box(watermark_annotation, img_w, img_h):
+    x_center, y_center, w, h, class_id = watermark_annotation
+
+    # Convert normalized coordinates to pixel values
+    x_center = x_center * img_w
+    y_center = y_center * img_h
+    w = w * img_w
+    h = h * img_h
+
+    # Calculate x_min, y_min, x_max, y_max
+    x_min = x_center - (w / 2)
+    y_min = y_center - (h / 2)
+    x_max = x_center + (w / 2)
+    y_max = y_center + (h / 2)
+
+    return x_min, y_min, x_max, y_max
+    

+ 68 - 0
watermark_verify/utils/utils.py

@@ -0,0 +1,68 @@
+import numpy as np
+from PIL import Image
+
+#---------------------------------------------------------#
+#   将图像转换成RGB图像,防止灰度图在预测时报错。
+#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+#---------------------------------------------------------#
+def cvtColor(image):
+    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+        return image 
+    else:
+        image = image.convert('RGB')
+        return image 
+
+#---------------------------------------------------#
+#   对输入图像进行resize
+#---------------------------------------------------#
+def resize_image(image, size, letterbox_image):
+    iw, ih  = image.size
+    w, h    = size
+    if letterbox_image:
+        scale   = min(w/iw, h/ih)
+        nw      = int(iw*scale)
+        nh      = int(ih*scale)
+
+        image   = image.resize((nw,nh), Image.BICUBIC)
+        new_image = Image.new('RGB', size, (128,128,128))
+        new_image.paste(image, ((w-nw)//2, (h-nh)//2))
+    else:
+        new_image = image.resize((w, h), Image.BICUBIC)
+    return new_image
+
+#---------------------------------------------------#
+#   获得类
+#---------------------------------------------------#
+def get_classes(classes_path):
+    with open(classes_path, encoding='utf-8') as f:
+        class_names = f.readlines()
+    class_names = [c.strip() for c in class_names]
+    return class_names, len(class_names)
+
+#---------------------------------------------------#
+#   获得学习率
+#---------------------------------------------------#
+def preprocess_input(inputs):
+    MEANS = (104, 117, 123)
+    return inputs - MEANS
+
+#---------------------------------------------------#
+#   获得学习率
+#---------------------------------------------------#
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+def download_weights(backbone, model_dir="./model_data"):
+    import os
+    from torch.hub import load_state_dict_from_url
+    
+    download_urls = {
+        'vgg'           : 'https://download.pytorch.org/models/vgg16-397923af.pth',
+        'mobilenetv2'   : 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
+    }
+    url = download_urls[backbone]
+    
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    load_state_dict_from_url(url, model_dir)

+ 132 - 0
watermark_verify/utils/utils_bbox.py

@@ -0,0 +1,132 @@
+import numpy as np
+import torch
+from torch import nn
+from torchvision.ops import nms
+
+
+class BBoxUtility(object):
+    def __init__(self, num_classes):
+        self.num_classes    = num_classes
+
+    def ssd_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
+        #-----------------------------------------------------------------#
+        #   把y轴放前面是因为方便预测框和图像的宽高进行相乘
+        #-----------------------------------------------------------------#
+        box_yx = box_xy[..., ::-1]
+        box_hw = box_wh[..., ::-1]
+        input_shape = np.array(input_shape)
+        image_shape = np.array(image_shape)
+
+        if letterbox_image:
+            #-----------------------------------------------------------------#
+            #   这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
+            #   new_shape指的是宽高缩放情况
+            #-----------------------------------------------------------------#
+            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
+            offset  = (input_shape - new_shape)/2./input_shape
+            scale   = input_shape/new_shape
+
+            box_yx  = (box_yx - offset) * scale
+            box_hw *= scale
+
+        box_mins    = box_yx - (box_hw / 2.)
+        box_maxes   = box_yx + (box_hw / 2.)
+        boxes  = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
+        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
+        return boxes
+
+    def decode_boxes(self, mbox_loc, anchors, variances):
+
+        # 获得先验框的宽与高
+        anchor_width     = anchors[:, 2] - anchors[:, 0]
+        anchor_height    = anchors[:, 3] - anchors[:, 1]
+        # 获得先验框的中心点
+        anchor_center_x  = 0.5 * (anchors[:, 2] + anchors[:, 0])
+        anchor_center_y  = 0.5 * (anchors[:, 3] + anchors[:, 1])
+
+        # 真实框距离先验框中心的xy轴偏移情况
+        decode_bbox_center_x = mbox_loc[:, 0] * anchor_width * variances[0]
+        decode_bbox_center_x += anchor_center_x
+        decode_bbox_center_y = mbox_loc[:, 1] * anchor_height * variances[0]
+        decode_bbox_center_y += anchor_center_y
+        
+        # 真实框的宽与高的求取
+        decode_bbox_width   = torch.exp(mbox_loc[:, 2] * variances[1])
+        decode_bbox_width   *= anchor_width
+        decode_bbox_height  = torch.exp(mbox_loc[:, 3] * variances[1])
+        decode_bbox_height  *= anchor_height
+
+        # 获取真实框的左上角与右下角
+        decode_bbox_xmin = decode_bbox_center_x - 0.5 * decode_bbox_width
+        decode_bbox_ymin = decode_bbox_center_y - 0.5 * decode_bbox_height
+        decode_bbox_xmax = decode_bbox_center_x + 0.5 * decode_bbox_width
+        decode_bbox_ymax = decode_bbox_center_y + 0.5 * decode_bbox_height
+
+        # 真实框的左上角与右下角进行堆叠
+        decode_bbox = torch.cat((decode_bbox_xmin[:, None],
+                                      decode_bbox_ymin[:, None],
+                                      decode_bbox_xmax[:, None],
+                                      decode_bbox_ymax[:, None]), dim=-1)
+        # 防止超出0与1
+        decode_bbox = torch.min(torch.max(decode_bbox, torch.zeros_like(decode_bbox)), torch.ones_like(decode_bbox))
+        return decode_bbox
+
+    def decode_box(self, predictions, anchors, image_shape, input_shape, letterbox_image, variances = [0.1, 0.2], nms_iou = 0.3, confidence = 0.5):
+        #---------------------------------------------------#
+        #   :4是回归预测结果
+        #---------------------------------------------------#
+        mbox_loc        = torch.from_numpy(predictions[0])
+        #---------------------------------------------------#
+        #   获得种类的置信度
+        #---------------------------------------------------#
+        mbox_conf       = nn.Softmax(-1)(torch.from_numpy(predictions[1]))
+
+        results = []
+        #----------------------------------------------------------------------------------------------------------------#
+        #   对每一张图片进行处理,由于在predict.py的时候,我们只输入一张图片,所以for i in range(len(mbox_loc))只进行一次
+        #----------------------------------------------------------------------------------------------------------------#
+        for i in range(len(mbox_loc)):
+            results.append([])
+            #--------------------------------#
+            #   利用回归结果对先验框进行解码
+            #--------------------------------#
+            decode_bbox = self.decode_boxes(mbox_loc[i], anchors, variances)
+
+            for c in range(1, self.num_classes):
+                #--------------------------------#
+                #   取出属于该类的所有框的置信度
+                #   判断是否大于门限
+                #--------------------------------#
+                c_confs     = mbox_conf[i, :, c]
+                c_confs_m   = c_confs > confidence
+                if len(c_confs[c_confs_m]) > 0:
+                    #-----------------------------------------#
+                    #   取出得分高于confidence的框
+                    #-----------------------------------------#
+                    boxes_to_process = decode_bbox[c_confs_m]
+                    confs_to_process = c_confs[c_confs_m]
+
+                    keep = nms(
+                        boxes_to_process,
+                        confs_to_process,
+                        nms_iou
+                    )
+                    #-----------------------------------------#
+                    #   取出在非极大抑制中效果较好的内容
+                    #-----------------------------------------#
+                    good_boxes  = boxes_to_process[keep]
+                    confs       = confs_to_process[keep][:, None]
+                    labels      = (c - 1) * torch.ones((len(keep), 1)).cuda() if confs.is_cuda else (c - 1) * torch.ones((len(keep), 1))
+                    #-----------------------------------------#
+                    #   将label、置信度、框的位置进行堆叠。
+                    #-----------------------------------------#
+                    c_pred      = torch.cat((good_boxes, labels, confs), dim=1).cpu().numpy()
+                    # 添加进result里
+                    results[-1].extend(c_pred)
+
+            if len(results[-1]) > 0:
+                results[-1] = np.array(results[-1])
+                box_xy, box_wh = (results[-1][:, 0:2] + results[-1][:, 2:4])/2, results[-1][:, 2:4] - results[-1][:, 0:2]
+                results[-1][:, :4] = self.ssd_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
+
+        return results

+ 105 - 0
watermark_verify/utils/utils_fit.py

@@ -0,0 +1,105 @@
+import os
+
+import torch
+from tqdm import tqdm
+
+from utils.utils import get_lr
+
+
+def fit_one_epoch(model_train, model, ssd_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
+    total_loss  = 0
+    val_loss    = 0 
+
+    if local_rank == 0:
+        print('Start Train')
+        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
+    model_train.train()
+    for iteration, batch in enumerate(gen):
+        if iteration >= epoch_step:
+            break
+        images, targets = batch[0], batch[1]
+        with torch.no_grad():
+            if cuda:
+                images  = images.cuda(local_rank)
+                targets = targets.cuda(local_rank)
+        if not fp16:
+            #----------------------#
+            #   前向传播
+            #----------------------#
+            out = model_train(images)
+            #----------------------#
+            #   清零梯度
+            #----------------------#
+            optimizer.zero_grad()
+            #----------------------#
+            #   计算损失
+            #----------------------#
+            loss = ssd_loss.forward(targets, out)
+            #----------------------#
+            #   反向传播
+            #----------------------#
+            loss.backward()
+            optimizer.step()
+        else:
+            from torch.cuda.amp import autocast
+            with autocast():
+                #----------------------#
+                #   前向传播
+                #----------------------#
+                out = model_train(images)
+                #----------------------#
+                #   清零梯度
+                #----------------------#
+                optimizer.zero_grad()
+                #----------------------#
+                #   计算损失
+                #----------------------#
+                loss = ssd_loss.forward(targets, out)
+
+            #----------------------#
+            #   反向传播
+            #----------------------#
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+        total_loss += loss.item()
+        
+        if local_rank == 0:
+            pbar.set_postfix(**{'total_loss'    : total_loss / (iteration + 1), 
+                                'lr'            : get_lr(optimizer)})
+            pbar.update(1)
+                
+    if local_rank == 0:
+        pbar.close()
+        print('Finish Train')
+        print('Start Validation')
+        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
+
+    model_train.eval()
+    for iteration, batch in enumerate(gen_val):
+        if iteration >= epoch_step_val:
+            break
+        images, targets = batch[0], batch[1]
+        with torch.no_grad():
+            if cuda:
+                images  = images.cuda(local_rank)
+                targets = targets.cuda(local_rank)
+
+            out     = model_train(images)
+            optimizer.zero_grad()
+            loss    = ssd_loss.forward(targets, out)
+            val_loss += loss.item()
+
+            if local_rank == 0:
+                pbar.set_postfix(**{'val_loss'      : val_loss / (iteration + 1), 
+                                    'lr'            : get_lr(optimizer)})
+                pbar.update(1)
+
+    if local_rank == 0:
+        print('Finish Validation')
+        loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
+        print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
+        print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
+        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))

+ 901 - 0
watermark_verify/utils/utils_map.py

@@ -0,0 +1,901 @@
+import glob
+import json
+import math
+import operator
+import os
+import shutil
+import sys
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+
+'''
+    0,0 ------> x (width)
+     |
+     |  (Left,Top)
+     |      *_________
+     |      |         |
+            |         |
+     y      |_________|
+  (height)            *
+                (Right,Bottom)
+'''
+
+def log_average_miss_rate(precision, fp_cumsum, num_images):
+    """
+        log-average miss rate:
+            Calculated by averaging miss rates at 9 evenly spaced FPPI points
+            between 10e-2 and 10e0, in log-space.
+
+        output:
+                lamr | log-average miss rate
+                mr | miss rate
+                fppi | false positives per image
+
+        references:
+            [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
+               State of the Art." Pattern Analysis and Machine Intelligence, IEEE
+               Transactions on 34.4 (2012): 743 - 761.
+    """
+
+    if precision.size == 0:
+        lamr = 0
+        mr = 1
+        fppi = 0
+        return lamr, mr, fppi
+
+    fppi = fp_cumsum / float(num_images)
+    mr = (1 - precision)
+
+    fppi_tmp = np.insert(fppi, 0, -1.0)
+    mr_tmp = np.insert(mr, 0, 1.0)
+
+    ref = np.logspace(-2.0, 0.0, num = 9)
+    for i, ref_i in enumerate(ref):
+        j = np.where(fppi_tmp <= ref_i)[-1][-1]
+        ref[i] = mr_tmp[j]
+
+    lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
+
+    return lamr, mr, fppi
+
+"""
+ throw error and exit
+"""
+def error(msg):
+    print(msg)
+    sys.exit(0)
+
+"""
+ check if the number is a float between 0.0 and 1.0
+"""
+def is_float_between_0_and_1(value):
+    try:
+        val = float(value)
+        if val > 0.0 and val < 1.0:
+            return True
+        else:
+            return False
+    except ValueError:
+        return False
+
+"""
+ Calculate the AP given the recall and precision array
+    1st) We compute a version of the measured precision/recall curve with
+         precision monotonically decreasing
+    2nd) We compute the AP as the area under this curve by numerical integration.
+"""
+def voc_ap(rec, prec):
+    """
+    --- Official matlab code VOC2012---
+    mrec=[0 ; rec ; 1];
+    mpre=[0 ; prec ; 0];
+    for i=numel(mpre)-1:-1:1
+            mpre(i)=max(mpre(i),mpre(i+1));
+    end
+    i=find(mrec(2:end)~=mrec(1:end-1))+1;
+    ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
+    """
+    rec.insert(0, 0.0) # insert 0.0 at begining of list
+    rec.append(1.0) # insert 1.0 at end of list
+    mrec = rec[:]
+    prec.insert(0, 0.0) # insert 0.0 at begining of list
+    prec.append(0.0) # insert 0.0 at end of list
+    mpre = prec[:]
+    """
+     This part makes the precision monotonically decreasing
+        (goes from the end to the beginning)
+        matlab: for i=numel(mpre)-1:-1:1
+                    mpre(i)=max(mpre(i),mpre(i+1));
+    """
+    for i in range(len(mpre)-2, -1, -1):
+        mpre[i] = max(mpre[i], mpre[i+1])
+    """
+     This part creates a list of indexes where the recall changes
+        matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
+    """
+    i_list = []
+    for i in range(1, len(mrec)):
+        if mrec[i] != mrec[i-1]:
+            i_list.append(i) # if it was matlab would be i + 1
+    """
+     The Average Precision (AP) is the area under the curve
+        (numerical integration)
+        matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
+    """
+    ap = 0.0
+    for i in i_list:
+        ap += ((mrec[i]-mrec[i-1])*mpre[i])
+    return ap, mrec, mpre
+
+
+"""
+ Convert the lines of a file to a list
+"""
+def file_lines_to_list(path):
+    # open txt file lines to a list
+    with open(path) as f:
+        content = f.readlines()
+    # remove whitespace characters like `\n` at the end of each line
+    content = [x.strip() for x in content]
+    return content
+
+"""
+ Draws text in image
+"""
+def draw_text_in_image(img, text, pos, color, line_width):
+    font = cv2.FONT_HERSHEY_PLAIN
+    fontScale = 1
+    lineType = 1
+    bottomLeftCornerOfText = pos
+    cv2.putText(img, text,
+            bottomLeftCornerOfText,
+            font,
+            fontScale,
+            color,
+            lineType)
+    text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
+    return img, (line_width + text_width)
+
+"""
+ Plot - adjust axes
+"""
+def adjust_axes(r, t, fig, axes):
+    # get text width for re-scaling
+    bb = t.get_window_extent(renderer=r)
+    text_width_inches = bb.width / fig.dpi
+    # get axis width in inches
+    current_fig_width = fig.get_figwidth()
+    new_fig_width = current_fig_width + text_width_inches
+    propotion = new_fig_width / current_fig_width
+    # get axis limit
+    x_lim = axes.get_xlim()
+    axes.set_xlim([x_lim[0], x_lim[1]*propotion])
+
+"""
+ Draw plot using Matplotlib
+"""
+def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
+    # sort the dictionary by decreasing value, into a list of tuples
+    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
+    # unpacking the list of tuples into two lists
+    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
+    # 
+    if true_p_bar != "":
+        """
+         Special case to draw in:
+            - green -> TP: True Positives (object detected and matches ground-truth)
+            - red -> FP: False Positives (object detected but does not match ground-truth)
+            - orange -> FN: False Negatives (object not detected but present in the ground-truth)
+        """
+        fp_sorted = []
+        tp_sorted = []
+        for key in sorted_keys:
+            fp_sorted.append(dictionary[key] - true_p_bar[key])
+            tp_sorted.append(true_p_bar[key])
+        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
+        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
+        # add legend
+        plt.legend(loc='lower right')
+        """
+         Write number on side of bar
+        """
+        fig = plt.gcf() # gcf - get current figure
+        axes = plt.gca()
+        r = fig.canvas.get_renderer()
+        for i, val in enumerate(sorted_values):
+            fp_val = fp_sorted[i]
+            tp_val = tp_sorted[i]
+            fp_str_val = " " + str(fp_val)
+            tp_str_val = fp_str_val + " " + str(tp_val)
+            # trick to paint multicolor with offset:
+            # first paint everything and then repaint the first number
+            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
+            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
+            if i == (len(sorted_values)-1): # largest bar
+                adjust_axes(r, t, fig, axes)
+    else:
+        plt.barh(range(n_classes), sorted_values, color=plot_color)
+        """
+         Write number on side of bar
+        """
+        fig = plt.gcf() # gcf - get current figure
+        axes = plt.gca()
+        r = fig.canvas.get_renderer()
+        for i, val in enumerate(sorted_values):
+            str_val = " " + str(val) # add a space before
+            if val < 1.0:
+                str_val = " {0:.2f}".format(val)
+            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
+            # re-set axes to show number inside the figure
+            if i == (len(sorted_values)-1): # largest bar
+                adjust_axes(r, t, fig, axes)
+    # set window title
+    fig.canvas.set_window_title(window_title)
+    # write classes in y axis
+    tick_font_size = 12
+    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
+    """
+     Re-scale height accordingly
+    """
+    init_height = fig.get_figheight()
+    # comput the matrix height in points and inches
+    dpi = fig.dpi
+    height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
+    height_in = height_pt / dpi
+    # compute the required figure height 
+    top_margin = 0.15 # in percentage of the figure height
+    bottom_margin = 0.05 # in percentage of the figure height
+    figure_height = height_in / (1 - top_margin - bottom_margin)
+    # set new height
+    if figure_height > init_height:
+        fig.set_figheight(figure_height)
+
+    # set plot title
+    plt.title(plot_title, fontsize=14)
+    # set axis titles
+    # plt.xlabel('classes')
+    plt.xlabel(x_label, fontsize='large')
+    # adjust size of window
+    fig.tight_layout()
+    # save the plot
+    fig.savefig(output_path)
+    # show image
+    if to_show:
+        plt.show()
+    # close the plot
+    plt.close()
+
+def get_map(MINOVERLAP, draw_plot, path = './map_out'):
+    GT_PATH             = os.path.join(path, 'ground-truth')
+    DR_PATH             = os.path.join(path, 'detection-results')
+    IMG_PATH            = os.path.join(path, 'images-optional')
+    TEMP_FILES_PATH     = os.path.join(path, '.temp_files')
+    RESULTS_FILES_PATH  = os.path.join(path, 'results')
+
+    show_animation = True
+    if os.path.exists(IMG_PATH): 
+        for dirpath, dirnames, files in os.walk(IMG_PATH):
+            if not files:
+                show_animation = False
+    else:
+        show_animation = False
+
+    if not os.path.exists(TEMP_FILES_PATH):
+        os.makedirs(TEMP_FILES_PATH)
+        
+    if os.path.exists(RESULTS_FILES_PATH):
+        shutil.rmtree(RESULTS_FILES_PATH)
+    if draw_plot:
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
+    if show_animation:
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
+
+    ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
+    if len(ground_truth_files_list) == 0:
+        error("Error: No ground-truth files found!")
+    ground_truth_files_list.sort()
+    gt_counter_per_class     = {}
+    counter_images_per_class = {}
+
+    for txt_file in ground_truth_files_list:
+        file_id     = txt_file.split(".txt", 1)[0]
+        file_id     = os.path.basename(os.path.normpath(file_id))
+        temp_path   = os.path.join(DR_PATH, (file_id + ".txt"))
+        if not os.path.exists(temp_path):
+            error_msg = "Error. File not found: {}\n".format(temp_path)
+            error(error_msg)
+        lines_list      = file_lines_to_list(txt_file)
+        bounding_boxes  = []
+        is_difficult    = False
+        already_seen_classes = []
+        for line in lines_list:
+            try:
+                if "difficult" in line:
+                    class_name, left, top, right, bottom, _difficult = line.split()
+                    is_difficult = True
+                else:
+                    class_name, left, top, right, bottom = line.split()
+            except:
+                if "difficult" in line:
+                    line_split  = line.split()
+                    _difficult  = line_split[-1]
+                    bottom      = line_split[-2]
+                    right       = line_split[-3]
+                    top         = line_split[-4]
+                    left        = line_split[-5]
+                    class_name  = ""
+                    for name in line_split[:-5]:
+                        class_name += name + " "
+                    class_name  = class_name[:-1]
+                    is_difficult = True
+                else:
+                    line_split  = line.split()
+                    bottom      = line_split[-1]
+                    right       = line_split[-2]
+                    top         = line_split[-3]
+                    left        = line_split[-4]
+                    class_name  = ""
+                    for name in line_split[:-4]:
+                        class_name += name + " "
+                    class_name = class_name[:-1]
+
+            bbox = left + " " + top + " " + right + " " + bottom
+            if is_difficult:
+                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
+                is_difficult = False
+            else:
+                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
+                if class_name in gt_counter_per_class:
+                    gt_counter_per_class[class_name] += 1
+                else:
+                    gt_counter_per_class[class_name] = 1
+
+                if class_name not in already_seen_classes:
+                    if class_name in counter_images_per_class:
+                        counter_images_per_class[class_name] += 1
+                    else:
+                        counter_images_per_class[class_name] = 1
+                    already_seen_classes.append(class_name)
+
+        with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
+            json.dump(bounding_boxes, outfile)
+
+    gt_classes  = list(gt_counter_per_class.keys())
+    gt_classes  = sorted(gt_classes)
+    n_classes   = len(gt_classes)
+
+    dr_files_list = glob.glob(DR_PATH + '/*.txt')
+    dr_files_list.sort()
+    for class_index, class_name in enumerate(gt_classes):
+        bounding_boxes = []
+        for txt_file in dr_files_list:
+            file_id = txt_file.split(".txt",1)[0]
+            file_id = os.path.basename(os.path.normpath(file_id))
+            temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
+            if class_index == 0:
+                if not os.path.exists(temp_path):
+                    error_msg = "Error. File not found: {}\n".format(temp_path)
+                    error(error_msg)
+            lines = file_lines_to_list(txt_file)
+            for line in lines:
+                try:
+                    tmp_class_name, confidence, left, top, right, bottom = line.split()
+                except:
+                    line_split      = line.split()
+                    bottom          = line_split[-1]
+                    right           = line_split[-2]
+                    top             = line_split[-3]
+                    left            = line_split[-4]
+                    confidence      = line_split[-5]
+                    tmp_class_name  = ""
+                    for name in line_split[:-5]:
+                        tmp_class_name += name + " "
+                    tmp_class_name  = tmp_class_name[:-1]
+
+                if tmp_class_name == class_name:
+                    bbox = left + " " + top + " " + right + " " +bottom
+                    bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
+
+        bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
+        with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
+            json.dump(bounding_boxes, outfile)
+
+    sum_AP = 0.0
+    ap_dictionary = {}
+    lamr_dictionary = {}
+    with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
+        results_file.write("# AP and precision/recall per class\n")
+        count_true_positives = {}
+
+        for class_index, class_name in enumerate(gt_classes):
+            count_true_positives[class_name] = 0
+            dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
+            dr_data = json.load(open(dr_file))
+
+            nd          = len(dr_data)
+            tp          = [0] * nd
+            fp          = [0] * nd
+            score       = [0] * nd
+            score05_idx = 0
+            for idx, detection in enumerate(dr_data):
+                file_id     = detection["file_id"]
+                score[idx]  = float(detection["confidence"])
+                if score[idx] > 0.5:
+                    score05_idx = idx
+
+                if show_animation:
+                    ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
+                    if len(ground_truth_img) == 0:
+                        error("Error. Image not found with id: " + file_id)
+                    elif len(ground_truth_img) > 1:
+                        error("Error. Multiple image with id: " + file_id)
+                    else:
+                        img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
+                        img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
+                        if os.path.isfile(img_cumulative_path):
+                            img_cumulative = cv2.imread(img_cumulative_path)
+                        else:
+                            img_cumulative = img.copy()
+                        bottom_border = 60
+                        BLACK = [0, 0, 0]
+                        img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
+
+                gt_file             = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
+                ground_truth_data   = json.load(open(gt_file))
+                ovmax       = -1
+                gt_match    = -1
+                bb          = [float(x) for x in detection["bbox"].split()]
+                for obj in ground_truth_data:
+                    if obj["class_name"] == class_name:
+                        bbgt    = [ float(x) for x in obj["bbox"].split() ]
+                        bi      = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
+                        iw      = bi[2] - bi[0] + 1
+                        ih      = bi[3] - bi[1] + 1
+                        if iw > 0 and ih > 0:
+                            ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
+                                            + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
+                            ov = iw * ih / ua
+                            if ov > ovmax:
+                                ovmax = ov
+                                gt_match = obj
+
+                if show_animation:
+                    status = "NO MATCH FOUND!" 
+                    
+                min_overlap = MINOVERLAP
+                if ovmax >= min_overlap:
+                    if "difficult" not in gt_match:
+                        if not bool(gt_match["used"]):
+                            tp[idx] = 1
+                            gt_match["used"] = True
+                            count_true_positives[class_name] += 1
+                            with open(gt_file, 'w') as f:
+                                    f.write(json.dumps(ground_truth_data))
+                            if show_animation:
+                                status = "MATCH!"
+                        else:
+                            fp[idx] = 1
+                            if show_animation:
+                                status = "REPEATED MATCH!"
+                else:
+                    fp[idx] = 1
+                    if ovmax > 0:
+                        status = "INSUFFICIENT OVERLAP"
+
+                """
+                Draw image to show animation
+                """
+                if show_animation:
+                    height, widht = img.shape[:2]
+                    white           = (255,255,255)
+                    light_blue      = (255,200,100)
+                    green           = (0,255,0)
+                    light_red       = (30,30,255)
+                    margin          = 10
+                    # 1nd line
+                    v_pos           = int(height - margin - (bottom_border / 2.0))
+                    text            = "Image: " + ground_truth_img[0] + " "
+                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
+                    text            = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
+                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
+                    if ovmax != -1:
+                        color       = light_red
+                        if status   == "INSUFFICIENT OVERLAP":
+                            text    = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
+                        else:
+                            text    = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
+                            color   = green
+                        img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
+                    # 2nd line
+                    v_pos           += int(bottom_border / 2.0)
+                    rank_pos        = str(idx+1)
+                    text            = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
+                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
+                    color           = light_red
+                    if status == "MATCH!":
+                        color = green
+                    text            = "Result: " + status + " "
+                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
+
+                    font = cv2.FONT_HERSHEY_SIMPLEX
+                    if ovmax > 0: 
+                        bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
+                        cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
+                        cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
+                        cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
+                    bb = [int(i) for i in bb]
+                    cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
+                    cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
+                    cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
+
+                    cv2.imshow("Animation", img)
+                    cv2.waitKey(20) 
+                    output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
+                    cv2.imwrite(output_img_path, img)
+                    cv2.imwrite(img_cumulative_path, img_cumulative)
+
+            cumsum = 0
+            for idx, val in enumerate(fp):
+                fp[idx] += cumsum
+                cumsum += val
+                
+            cumsum = 0
+            for idx, val in enumerate(tp):
+                tp[idx] += cumsum
+                cumsum += val
+
+            rec = tp[:]
+            for idx, val in enumerate(tp):
+                rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
+
+            prec = tp[:]
+            for idx, val in enumerate(tp):
+                prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
+
+            ap, mrec, mprec = voc_ap(rec[:], prec[:])
+            F1  = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
+
+            sum_AP  += ap
+            text    = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
+
+            if len(prec)>0:
+                F1_text         = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
+                Recall_text     = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
+                Precision_text  = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
+            else:
+                F1_text         = "0.00" + " = " + class_name + " F1 " 
+                Recall_text     = "0.00%" + " = " + class_name + " Recall " 
+                Precision_text  = "0.00%" + " = " + class_name + " Precision " 
+
+            rounded_prec    = [ '%.2f' % elem for elem in prec ]
+            rounded_rec     = [ '%.2f' % elem for elem in rec ]
+            results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
+            if len(prec)>0:
+                print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
+                    + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
+            else:
+                print(text + "\t||\tscore_threhold=0.5 : F1=0.00% ; Recall=0.00% ; Precision=0.00%")
+            ap_dictionary[class_name] = ap
+
+            n_images = counter_images_per_class[class_name]
+            lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
+            lamr_dictionary[class_name] = lamr
+
+            if draw_plot:
+                plt.plot(rec, prec, '-o')
+                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
+                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
+                plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
+
+                fig = plt.gcf()
+                fig.canvas.set_window_title('AP ' + class_name)
+
+                plt.title('class: ' + text)
+                plt.xlabel('Recall')
+                plt.ylabel('Precision')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05]) 
+                fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, F1, "-", color='orangered')
+                plt.title('class: ' + F1_text + "\nscore_threhold=0.5")
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('F1')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, rec, "-H", color='gold')
+                plt.title('class: ' + Recall_text + "\nscore_threhold=0.5")
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('Recall')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, prec, "-s", color='palevioletred')
+                plt.title('class: ' + Precision_text + "\nscore_threhold=0.5")
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('Precision')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
+                plt.cla()
+                
+        if show_animation:
+            cv2.destroyAllWindows()
+
+        results_file.write("\n# mAP of all classes\n")
+        mAP     = sum_AP / n_classes
+        text    = "mAP = {0:.2f}%".format(mAP*100)
+        results_file.write(text + "\n")
+        print(text)
+
+    shutil.rmtree(TEMP_FILES_PATH)
+
+    """
+    Count total of detection-results
+    """
+    det_counter_per_class = {}
+    for txt_file in dr_files_list:
+        lines_list = file_lines_to_list(txt_file)
+        for line in lines_list:
+            class_name = line.split()[0]
+            if class_name in det_counter_per_class:
+                det_counter_per_class[class_name] += 1
+            else:
+                det_counter_per_class[class_name] = 1
+    dr_classes = list(det_counter_per_class.keys())
+
+    """
+    Write number of ground-truth objects per class to results.txt
+    """
+    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+        results_file.write("\n# Number of ground-truth objects per class\n")
+        for class_name in sorted(gt_counter_per_class):
+            results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
+
+    """
+    Finish counting true positives
+    """
+    for class_name in dr_classes:
+        if class_name not in gt_classes:
+            count_true_positives[class_name] = 0
+
+    """
+    Write number of detected objects per class to results.txt
+    """
+    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+        results_file.write("\n# Number of detected objects per class\n")
+        for class_name in sorted(dr_classes):
+            n_det = det_counter_per_class[class_name]
+            text = class_name + ": " + str(n_det)
+            text += " (tp:" + str(count_true_positives[class_name]) + ""
+            text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
+            results_file.write(text)
+
+    """
+    Plot the total number of occurences of each class in the ground-truth
+    """
+    if draw_plot:
+        window_title = "ground-truth-info"
+        plot_title = "ground-truth\n"
+        plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
+        x_label = "Number of objects per class"
+        output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
+        to_show = False
+        plot_color = 'forestgreen'
+        draw_plot_func(
+            gt_counter_per_class,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            '',
+            )
+
+    # """
+    # Plot the total number of occurences of each class in the "detection-results" folder
+    # """
+    # if draw_plot:
+    #     window_title = "detection-results-info"
+    #     # Plot title
+    #     plot_title = "detection-results\n"
+    #     plot_title += "(" + str(len(dr_files_list)) + " files and "
+    #     count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
+    #     plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
+    #     # end Plot title
+    #     x_label = "Number of objects per class"
+    #     output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
+    #     to_show = False
+    #     plot_color = 'forestgreen'
+    #     true_p_bar = count_true_positives
+    #     draw_plot_func(
+    #         det_counter_per_class,
+    #         len(det_counter_per_class),
+    #         window_title,
+    #         plot_title,
+    #         x_label,
+    #         output_path,
+    #         to_show,
+    #         plot_color,
+    #         true_p_bar
+    #         )
+
+    """
+    Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
+    """
+    if draw_plot:
+        window_title = "lamr"
+        plot_title = "log-average miss rate"
+        x_label = "log-average miss rate"
+        output_path = RESULTS_FILES_PATH + "/lamr.png"
+        to_show = False
+        plot_color = 'royalblue'
+        draw_plot_func(
+            lamr_dictionary,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            ""
+            )
+
+    """
+    Draw mAP plot (Show AP's of all classes in decreasing order)
+    """
+    if draw_plot:
+        window_title = "mAP"
+        plot_title = "mAP = {0:.2f}%".format(mAP*100)
+        x_label = "Average Precision"
+        output_path = RESULTS_FILES_PATH + "/mAP.png"
+        to_show = True
+        plot_color = 'royalblue'
+        draw_plot_func(
+            ap_dictionary,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            ""
+            )
+
+def preprocess_gt(gt_path, class_names):
+    image_ids   = os.listdir(gt_path)
+    results = {}
+
+    images = []
+    bboxes = []
+    for i, image_id in enumerate(image_ids):
+        lines_list      = file_lines_to_list(os.path.join(gt_path, image_id))
+        boxes_per_image = []
+        image           = {}
+        image_id        = os.path.splitext(image_id)[0]
+        image['file_name'] = image_id + '.jpg'
+        image['width']     = 1
+        image['height']    = 1
+        #-----------------------------------------------------------------#
+        #   感谢 多学学英语吧 的提醒
+        #   解决了'Results do not correspond to current coco set'问题
+        #-----------------------------------------------------------------#
+        image['id']        = str(image_id)
+
+        for line in lines_list:
+            difficult = 0 
+            if "difficult" in line:
+                line_split  = line.split()
+                left, top, right, bottom, _difficult = line_split[-5:]
+                class_name  = ""
+                for name in line_split[:-5]:
+                    class_name += name + " "
+                class_name  = class_name[:-1]
+                difficult = 1
+            else:
+                line_split  = line.split()
+                left, top, right, bottom = line_split[-4:]
+                class_name  = ""
+                for name in line_split[:-4]:
+                    class_name += name + " "
+                class_name = class_name[:-1]
+            
+            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+            cls_id  = class_names.index(class_name) + 1
+            bbox    = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
+            boxes_per_image.append(bbox)
+        images.append(image)
+        bboxes.extend(boxes_per_image)
+    results['images']        = images
+
+    categories = []
+    for i, cls in enumerate(class_names):
+        category = {}
+        category['supercategory']   = cls
+        category['name']            = cls
+        category['id']              = i + 1
+        categories.append(category)
+    results['categories']   = categories
+
+    annotations = []
+    for i, box in enumerate(bboxes):
+        annotation = {}
+        annotation['area']        = box[-1]
+        annotation['category_id'] = box[-2]
+        annotation['image_id']    = box[-3]
+        annotation['iscrowd']     = box[-4]
+        annotation['bbox']        = box[:4]
+        annotation['id']          = i
+        annotations.append(annotation)
+    results['annotations'] = annotations
+    return results
+
+def preprocess_dr(dr_path, class_names):
+    image_ids = os.listdir(dr_path)
+    results = []
+    for image_id in image_ids:
+        lines_list      = file_lines_to_list(os.path.join(dr_path, image_id))
+        image_id        = os.path.splitext(image_id)[0]
+        for line in lines_list:
+            line_split  = line.split()
+            confidence, left, top, right, bottom = line_split[-5:]
+            class_name  = ""
+            for name in line_split[:-5]:
+                class_name += name + " "
+            class_name  = class_name[:-1]
+            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+            result                  = {}
+            result["image_id"]      = str(image_id)
+            result["category_id"]   = class_names.index(class_name) + 1
+            result["bbox"]          = [left, top, right - left, bottom - top]
+            result["score"]         = float(confidence)
+            results.append(result)
+    return results
+ 
+def get_coco_map(class_names, path):
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+    
+    GT_PATH     = os.path.join(path, 'ground-truth')
+    DR_PATH     = os.path.join(path, 'detection-results')
+    COCO_PATH   = os.path.join(path, 'coco_eval')
+
+    if not os.path.exists(COCO_PATH):
+        os.makedirs(COCO_PATH)
+
+    GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
+    DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
+
+    with open(GT_JSON_PATH, "w") as f:
+        results_gt  = preprocess_gt(GT_PATH, class_names)
+        json.dump(results_gt, f, indent=4)
+
+    with open(DR_JSON_PATH, "w") as f:
+        results_dr  = preprocess_dr(DR_PATH, class_names)
+        json.dump(results_dr, f, indent=4)
+
+    cocoGt      = COCO(GT_JSON_PATH)
+    cocoDt      = cocoGt.loadRes(DR_JSON_PATH)
+    cocoEval    = COCOeval(cocoGt, cocoDt, 'bbox') 
+    cocoEval.evaluate()
+    cocoEval.accumulate()
+    cocoEval.summarize()

+ 2 - 2
watermark_verify/verify_tool.py

@@ -1,6 +1,6 @@
 import os
 
-from watermark_verify.inference import yolox
+from watermark_verify.inference import ssd
 from watermark_verify import logger
 from watermark_verify.tools import secret_label_func, qrcode_tool, general_tool, parse_qrcode_label_file
 
@@ -46,7 +46,7 @@ def label_verification(model_filename: str) -> bool:
     for cls, images in cls_image_mapping.items():
         for image in images:
             image_path = os.path.join(trigger_dir, image)
-            detect_result = yolox.predict_and_detect(image_path, model_filename, qrcode_positions_file, (640, 640))
+            detect_result = ssd.predict_and_detect(image_path, model_filename, qrcode_positions_file, (300, 300))
             if detect_result:
                 accessed_cls.add(cls)
                 break