Parcourir la source

初始化faster-rcnn代码

liyan il y a 11 mois
commit
5c97cedbe0
26 fichiers modifiés avec 8413 ajouts et 0 suppressions
  1. 12 0
      .gitignore
  2. 4069 0
      Faster R-CNN 论文复现代码.md
  3. 270 0
      Faster R-CNN代码使用说明书.md
  4. 331 0
      frcnn.py
  5. 140 0
      get_map.py
  6. 1 0
      nets/__init__.py
  7. 119 0
      nets/classifier.py
  8. 109 0
      nets/frcnn.py
  9. 393 0
      nets/frcnn_training.py
  10. 130 0
      nets/resnet50.py
  11. 191 0
      nets/rpn.py
  12. 111 0
      nets/vgg16.py
  13. 137 0
      predict.py
  14. 11 0
      requirements.txt
  15. 29 0
      summary.py
  16. 444 0
      train.py
  17. 1 0
      utils/__init__.py
  18. 67 0
      utils/anchors.py
  19. 237 0
      utils/callbacks.py
  20. 165 0
      utils/dataloader.py
  21. 62 0
      utils/utils.py
  22. 131 0
      utils/utils_bbox.py
  23. 76 0
      utils/utils_fit.py
  24. 923 0
      utils/utils_map.py
  25. 150 0
      voc_annotation.py
  26. 104 0
      watermarking.py

+ 12 - 0
.gitignore

@@ -0,0 +1,12 @@
+.idea
+img
+img_out
+img_out_pt
+img_ssd
+logs
+logs_wm
+map_out
+model_data
+VOCdevkit
+VOCdevkit_record
+2007*.txt

Fichier diff supprimé car celui-ci est trop grand
+ 4069 - 0
Faster R-CNN 论文复现代码.md


Fichier diff supprimé car celui-ci est trop grand
+ 270 - 0
Faster R-CNN代码使用说明书.md


+ 331 - 0
frcnn.py

@@ -0,0 +1,331 @@
+import colorsys
+import os
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image, ImageDraw, ImageFont
+from nets.frcnn import FasterRCNN
+from utils.utils import (cvtColor, get_classes, get_new_img_size, resize_image, preprocess_input, show_config)
+from utils.utils_bbox import DecodeBox
+
+
+#--------------------------------------------#
+#   使用自己训练好的模型预测需要修改2个参数
+#   model_path和classes_path都需要修改!
+#   如果出现shape不匹配
+#   一定要注意训练时的NUM_CLASSES、
+#   model_path和classes_path参数的修改
+#--------------------------------------------#
+class FRCNN(object):
+    _defaults = {
+        #--------------------------------------------------------------------------#
+        #   使用自己训练好的模型进行预测一定要修改model_path和classes_path!
+        #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
+        #
+        #   训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
+        #   验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
+        #   如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
+        #--------------------------------------------------------------------------#
+        "model_path"    : '/root/autodl-tmp/faster-rcnn-pytorch-master/logs_wm/best_epoch_weights.pth',
+        "classes_path"  : '/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/voc_classes.txt',
+        #---------------------------------------------------------------------#
+        #   网络的主干特征提取网络,resnet50或者vgg
+        #---------------------------------------------------------------------#
+        "backbone"      : "resnet50",
+        #---------------------------------------------------------------------#
+        #   只有得分大于置信度的预测框会被保留下来
+        #---------------------------------------------------------------------#
+        "confidence"    : 0.5,
+        #---------------------------------------------------------------------#
+        #   非极大抑制所用到的nms_iou大小
+        #---------------------------------------------------------------------#
+        "nms_iou"       : 0.3,
+        #---------------------------------------------------------------------#
+        #   用于指定先验框的大小
+        #---------------------------------------------------------------------#
+        'anchors_size'  : [8, 16, 32],
+        #-------------------------------#
+        #   是否使用Cuda
+        #   没有GPU可以设置成False
+        #-------------------------------#
+        "cuda"          : True,
+    }
+
+    @classmethod
+    def get_defaults(cls, n):
+        if n in cls._defaults:
+            return cls._defaults[n]
+        else:
+            return "Unrecognized attribute name '" + n + "'"
+
+    #---------------------------------------------------#
+    #   初始化faster RCNN
+    #---------------------------------------------------#
+    def __init__(self, **kwargs):
+        self.__dict__.update(self._defaults)
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+            self._defaults[name] = value 
+        #---------------------------------------------------#
+        #   获得种类和先验框的数量
+        #---------------------------------------------------#
+        self.class_names, self.num_classes  = get_classes(self.classes_path)
+
+        self.std    = torch.Tensor([0.1, 0.1, 0.2, 0.2]).repeat(self.num_classes + 1)[None]
+        if self.cuda:
+            self.std    = self.std.cuda()
+        self.bbox_util  = DecodeBox(self.std, self.num_classes)
+
+        #---------------------------------------------------#
+        #   画框设置不同的颜色
+        #---------------------------------------------------#
+        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
+        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
+        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
+        self.generate()
+
+        show_config(**self._defaults)
+
+    #---------------------------------------------------#
+    #   载入模型
+    #---------------------------------------------------#
+    def generate(self):
+        #-------------------------------#
+        #   载入模型与权值
+        #-------------------------------#
+        self.net    = FasterRCNN(self.num_classes, "predict", anchor_scales = self.anchors_size, backbone = self.backbone)
+        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
+        self.net    = self.net.eval()
+        print('{} model, anchors, and classes loaded.'.format(self.model_path))
+        
+        if self.cuda:
+            self.net = nn.DataParallel(self.net)
+            self.net = self.net.cuda()
+    
+    #---------------------------------------------------#
+    #   检测图片
+    #---------------------------------------------------#
+    def detect_image(self, image, crop = False, count = False):
+        #---------------------------------------------------#
+        #   计算输入图片的高和宽
+        #---------------------------------------------------#
+        image_shape = np.array(np.shape(image)[0:2])
+        #---------------------------------------------------#
+        #   计算resize后的图片的大小,resize后的图片短边为600
+        #---------------------------------------------------#
+        input_shape = get_new_img_size(image_shape[0], image_shape[1])
+        #---------------------------------------------------------#
+        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        #---------------------------------------------------------#
+        #   给原图像进行resize,resize到短边为600的大小上
+        #---------------------------------------------------------#
+        image_data  = resize_image(image, [input_shape[1], input_shape[0]])
+        #---------------------------------------------------------#
+        #   添加上batch_size维度
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+            
+            #-------------------------------------------------------------#
+            #   roi_cls_locs  建议框的调整参数
+            #   roi_scores    建议框的种类得分
+            #   rois          建议框的坐标
+            #-------------------------------------------------------------#
+            roi_cls_locs, roi_scores, rois, _ = self.net(images)
+            #-------------------------------------------------------------#
+            #   利用classifier的预测结果对建议框进行解码,获得预测框
+            #-------------------------------------------------------------#
+            results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape, 
+                                                    nms_iou = self.nms_iou, confidence = self.confidence)
+            #---------------------------------------------------------#
+            #   如果没有检测出物体,返回原图
+            #---------------------------------------------------------#           
+            if len(results[0]) <= 0:
+                return image
+                
+            top_label   = np.array(results[0][:, 5], dtype = 'int32')
+            top_conf    = results[0][:, 4]
+            top_boxes   = results[0][:, :4]
+        
+        #---------------------------------------------------------#
+        #   设置字体与边框厚度
+        #---------------------------------------------------------#
+        font        = ImageFont.truetype(font='/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
+        thickness   = int(max((image.size[0] + image.size[1]) // np.mean(input_shape), 1))
+        #---------------------------------------------------------#
+        #   计数
+        #---------------------------------------------------------#
+        if count:
+            print("top_label:", top_label)
+            classes_nums    = np.zeros([self.num_classes])
+            for i in range(self.num_classes):
+                num = np.sum(top_label == i)
+                if num > 0:
+                    print(self.class_names[i], " : ", num)
+                classes_nums[i] = num
+            print("classes_nums:", classes_nums)
+        #---------------------------------------------------------#
+        #   是否进行目标的裁剪
+        #---------------------------------------------------------#
+        if crop:
+            for i, c in list(enumerate(top_label)):
+                top, left, bottom, right = top_boxes[i]
+                top     = max(0, np.floor(top).astype('int32'))
+                left    = max(0, np.floor(left).astype('int32'))
+                bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
+                right   = min(image.size[0], np.floor(right).astype('int32'))
+                
+                dir_save_path = "img_crop"
+                if not os.path.exists(dir_save_path):
+                    os.makedirs(dir_save_path)
+                crop_image = image.crop([left, top, right, bottom])
+                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
+                print("save crop_" + str(i) + ".png to " + dir_save_path)
+        #---------------------------------------------------------#
+        #   图像绘制
+        #---------------------------------------------------------#
+        for i, c in list(enumerate(top_label)):
+            predicted_class = self.class_names[int(c)]
+            box             = top_boxes[i]
+            score           = top_conf[i]
+
+            top, left, bottom, right = box
+
+            top     = max(0, np.floor(top).astype('int32'))
+            left    = max(0, np.floor(left).astype('int32'))
+            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
+            right   = min(image.size[0], np.floor(right).astype('int32'))
+
+            label = '{} {:.2f}'.format(predicted_class, score)
+            draw = ImageDraw.Draw(image)
+            label_size = draw.textsize(label, font)
+            label = label.encode('utf-8')
+            # print(label, top, left, bottom, right)
+            
+            if top - label_size[1] >= 0:
+                text_origin = np.array([left, top - label_size[1]])
+            else:
+                text_origin = np.array([left, top + 1])
+
+            for i in range(thickness):
+                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
+            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
+            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
+            del draw
+
+        return image
+
+    def get_FPS(self, image, test_interval):
+        #---------------------------------------------------#
+        #   计算输入图片的高和宽
+        #---------------------------------------------------#
+        image_shape = np.array(np.shape(image)[0:2])
+        input_shape = get_new_img_size(image_shape[0], image_shape[1])
+        #---------------------------------------------------------#
+        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        
+        #---------------------------------------------------------#
+        #   给原图像进行resize,resize到短边为600的大小上
+        #---------------------------------------------------------#
+        image_data  = resize_image(image, [input_shape[1], input_shape[0]])
+        #---------------------------------------------------------#
+        #   添加上batch_size维度
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            roi_cls_locs, roi_scores, rois, _ = self.net(images)
+            #-------------------------------------------------------------#
+            #   利用classifier的预测结果对建议框进行解码,获得预测框
+            #-------------------------------------------------------------#
+            results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape, 
+                                                    nms_iou = self.nms_iou, confidence = self.confidence)
+        t1 = time.time()
+        for _ in range(test_interval):
+            with torch.no_grad():
+                roi_cls_locs, roi_scores, rois, _ = self.net(images)
+                #-------------------------------------------------------------#
+                #   利用classifier的预测结果对建议框进行解码,获得预测框
+                #-------------------------------------------------------------#
+                results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape, 
+                                                        nms_iou = self.nms_iou, confidence = self.confidence)
+                
+        t2 = time.time()
+        tact_time = (t2 - t1) / test_interval
+        return tact_time
+
+    #---------------------------------------------------#
+    #   检测图片
+    #---------------------------------------------------#
+    def get_map_txt(self, image_id, image, class_names, map_out_path):
+        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
+        #---------------------------------------------------#
+        #   计算输入图片的高和宽
+        #---------------------------------------------------#
+        image_shape = np.array(np.shape(image)[0:2])
+        input_shape = get_new_img_size(image_shape[0], image_shape[1])
+        #---------------------------------------------------------#
+        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        
+        #---------------------------------------------------------#
+        #   给原图像进行resize,resize到短边为600的大小上
+        #---------------------------------------------------------#
+        image_data  = resize_image(image, [input_shape[1], input_shape[0]])
+        #---------------------------------------------------------#
+        #   添加上batch_size维度
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            roi_cls_locs, roi_scores, rois, _ = self.net(images)
+            #-------------------------------------------------------------#
+            #   利用classifier的预测结果对建议框进行解码,获得预测框
+            #-------------------------------------------------------------#
+            results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape, 
+                                                    nms_iou = self.nms_iou, confidence = self.confidence)
+            #--------------------------------------#
+            #   如果没有检测到物体,则返回原图
+            #--------------------------------------#
+            if len(results[0]) <= 0:
+                return 
+
+            top_label   = np.array(results[0][:, 5], dtype = 'int32')
+            top_conf    = results[0][:, 4]
+            top_boxes   = results[0][:, :4]
+        
+        for i, c in list(enumerate(top_label)):
+            predicted_class = self.class_names[int(c)]
+            box             = top_boxes[i]
+            score           = str(top_conf[i])
+
+            top, left, bottom, right = box
+            if predicted_class not in class_names:
+                continue
+
+            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
+
+        f.close()
+        return 

+ 140 - 0
get_map.py

@@ -0,0 +1,140 @@
+import os
+import xml.etree.ElementTree as ET
+
+from PIL import Image
+from tqdm import tqdm
+
+from utils.utils import get_classes
+from utils.utils_map import get_coco_map, get_map
+from frcnn import FRCNN
+
+if __name__ == "__main__":
+    '''
+    Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。
+    默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。
+
+    受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值
+    因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框,
+    '''
+    #------------------------------------------------------------------------------------------------------------------#
+    #   map_mode用于指定该文件运行时计算的内容
+    #   map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。
+    #   map_mode为1代表仅仅获得预测结果。
+    #   map_mode为2代表仅仅获得真实框。
+    #   map_mode为3代表仅仅计算VOC_map。
+    #   map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
+    #-------------------------------------------------------------------------------------------------------------------#
+    map_mode        = 0
+    #--------------------------------------------------------------------------------------#
+    #   此处的classes_path用于指定需要测量VOC_map的类别
+    #   一般情况下与训练和预测所用的classes_path一致即可
+    #--------------------------------------------------------------------------------------#
+    classes_path    = '/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/voc_classes.txt'
+    #--------------------------------------------------------------------------------------#
+    #   MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
+    #   比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
+    #
+    #   当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。
+    #   因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低,
+    #--------------------------------------------------------------------------------------#
+    MINOVERLAP      = 0.5
+    #--------------------------------------------------------------------------------------#
+    #   受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP
+    #   因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。
+    #   
+    #   该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。
+    #   想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。
+    #--------------------------------------------------------------------------------------#
+    confidence      = 0.02
+    #--------------------------------------------------------------------------------------#
+    #   预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
+    #   
+    #   该值一般不调整。
+    #--------------------------------------------------------------------------------------#
+    nms_iou         = 0.5
+    #---------------------------------------------------------------------------------------------------------------#
+    #   Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。
+    #   
+    #   默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。
+    #   因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。
+    #   这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。
+    #---------------------------------------------------------------------------------------------------------------#
+    score_threhold  = 0.5
+    #-------------------------------------------------------#
+    #   map_vis用于指定是否开启VOC_map计算的可视化
+    #-------------------------------------------------------#
+    map_vis         = False
+    #-------------------------------------------------------#
+    #   指向VOC数据集所在的文件夹
+    #   默认指向根目录下的VOC数据集
+    #-------------------------------------------------------#
+    VOCdevkit_path  = '/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007'
+    #-------------------------------------------------------#
+    #   结果输出的文件夹,默认为map_out
+    #-------------------------------------------------------#
+    map_out_path    = 'map_out'
+    path_temp = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007/ImageSets/Main/test.txt"
+    image_ids = open(path_temp).read().strip().split()
+
+    if not os.path.exists(map_out_path):
+        os.makedirs(map_out_path)
+    if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
+        os.makedirs(os.path.join(map_out_path, 'ground-truth'))
+    if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
+        os.makedirs(os.path.join(map_out_path, 'detection-results'))
+    if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
+        os.makedirs(os.path.join(map_out_path, 'images-optional'))
+
+    class_names, _ = get_classes(classes_path)
+
+    if map_mode == 0 or map_mode == 1:
+        print("Load model.")
+        frcnn = FRCNN(confidence = confidence, nms_iou = nms_iou)
+        print("Load model done.")
+
+        print("Get predict result.")
+        for image_id in tqdm(image_ids):
+            img_path = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007/JPEGImages/"
+            image_path  = os.path.join(img_path + image_id+".jpg")
+            image       = Image.open(image_path)
+            if map_vis:
+                image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
+            frcnn.get_map_txt(image_id, image, class_names, map_out_path)
+        print("Get predict result done.")
+        
+    if map_mode == 0 or map_mode == 2:
+        print("Get ground truth result.")
+        for image_id in tqdm(image_ids):
+            with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
+                root_path = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007/Annotations/"
+                root = ET.parse(os.path.join(root_path + image_id+".xml")).getroot()
+                for obj in root.findall('object'):
+                    difficult_flag = False
+                    if obj.find('difficult')!=None:
+                        difficult = obj.find('difficult').text
+                        if int(difficult)==1:
+                            difficult_flag = True
+                    obj_name = obj.find('name').text
+                    if obj_name not in class_names:
+                        continue
+                    bndbox  = obj.find('bndbox')
+                    left    = bndbox.find('xmin').text
+                    top     = bndbox.find('ymin').text
+                    right   = bndbox.find('xmax').text
+                    bottom  = bndbox.find('ymax').text
+
+                    if difficult_flag:
+                        new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
+                    else:
+                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
+        print("Get ground truth result done.")
+
+    if map_mode == 0 or map_mode == 3:
+        print("Get map.")
+        get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
+        print("Get map done.")
+
+    if map_mode == 4:
+        print("Get map.")
+        get_coco_map(class_names = class_names, path = map_out_path)
+        print("Get map done.")

+ 1 - 0
nets/__init__.py

@@ -0,0 +1 @@
+#

+ 119 - 0
nets/classifier.py

@@ -0,0 +1,119 @@
+import warnings
+
+import torch
+from torch import nn
+from torchvision.ops import RoIPool
+
+warnings.filterwarnings("ignore")
+
+class VGG16RoIHead(nn.Module):
+    def __init__(self, n_class, roi_size, spatial_scale, classifier):
+        super(VGG16RoIHead, self).__init__()
+        self.classifier = classifier
+        #--------------------------------------#
+        #   对ROIPooling后的的结果进行回归预测
+        #--------------------------------------#
+        self.cls_loc    = nn.Linear(4096, n_class * 4)
+        #-----------------------------------#
+        #   对ROIPooling后的的结果进行分类
+        #-----------------------------------#
+        self.score      = nn.Linear(4096, n_class)
+        #-----------------------------------#
+        #   权值初始化
+        #-----------------------------------#
+        normal_init(self.cls_loc, 0, 0.001)
+        normal_init(self.score, 0, 0.01)
+
+        self.roi = RoIPool((roi_size, roi_size), spatial_scale)
+        
+    def forward(self, x, rois, roi_indices, img_size):
+        n, _, _, _ = x.shape
+        if x.is_cuda:
+            roi_indices = roi_indices.cuda()
+            rois = rois.cuda()
+        rois        = torch.flatten(rois, 0, 1)
+        roi_indices = torch.flatten(roi_indices, 0, 1)
+
+        rois_feature_map = torch.zeros_like(rois)
+        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * x.size()[3]
+        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * x.size()[2]
+
+        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim = 1)
+        #-----------------------------------#
+        #   利用建议框对公用特征层进行截取
+        #-----------------------------------#
+        pool = self.roi(x, indices_and_rois)
+        #-----------------------------------#
+        #   利用classifier网络进行特征提取
+        #-----------------------------------#
+        pool = pool.view(pool.size(0), -1)
+        #--------------------------------------------------------------#
+        #   当输入为一张图片的时候,这里获得的f7的shape为[300, 4096]
+        #--------------------------------------------------------------#
+        fc7 = self.classifier(pool)
+
+        roi_cls_locs    = self.cls_loc(fc7)
+        roi_scores      = self.score(fc7)
+
+        roi_cls_locs    = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
+        roi_scores      = roi_scores.view(n, -1, roi_scores.size(1))
+        return roi_cls_locs, roi_scores
+
+class Resnet50RoIHead(nn.Module):
+    def __init__(self, n_class, roi_size, spatial_scale, classifier):
+        super(Resnet50RoIHead, self).__init__()
+        self.classifier = classifier
+        #--------------------------------------#
+        #   对ROIPooling后的的结果进行回归预测
+        #--------------------------------------#
+        self.cls_loc = nn.Linear(2048, n_class * 4)
+        #-----------------------------------#
+        #   对ROIPooling后的的结果进行分类
+        #-----------------------------------#
+        self.score = nn.Linear(2048, n_class)
+        #-----------------------------------#
+        #   权值初始化
+        #-----------------------------------#
+        normal_init(self.cls_loc, 0, 0.001)
+        normal_init(self.score, 0, 0.01)
+
+        self.roi = RoIPool((roi_size, roi_size), spatial_scale)
+
+    def forward(self, x, rois, roi_indices, img_size):
+        n, _, _, _ = x.shape
+        if x.is_cuda:
+            roi_indices = roi_indices.cuda()
+            rois = rois.cuda()
+        rois        = torch.flatten(rois, 0, 1)
+        roi_indices = torch.flatten(roi_indices, 0, 1)
+        
+        rois_feature_map = torch.zeros_like(rois)
+        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * x.size()[3]
+        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * x.size()[2]
+
+        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim = 1)
+        #-----------------------------------#
+        #   利用建议框对公用特征层进行截取
+        #-----------------------------------#
+        pool = self.roi(x, indices_and_rois)
+        #-----------------------------------#
+        #   利用classifier网络进行特征提取
+        #-----------------------------------#
+        fc7 = self.classifier(pool)
+        #--------------------------------------------------------------#
+        #   当输入为一张图片的时候,这里获得的f7的shape为[300, 2048]
+        #--------------------------------------------------------------#
+        fc7 = fc7.view(fc7.size(0), -1)
+
+        roi_cls_locs    = self.cls_loc(fc7)
+        roi_scores      = self.score(fc7)
+        roi_cls_locs    = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
+        roi_scores      = roi_scores.view(n, -1, roi_scores.size(1))
+        return roi_cls_locs, roi_scores
+
+def normal_init(m, mean, stddev, truncated = False):
+    if truncated:
+        m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)  # not a perfect approximation
+    else:
+        m.weight.data.normal_(mean, stddev)
+        m.bias.data.zero_()

+ 109 - 0
nets/frcnn.py

@@ -0,0 +1,109 @@
+import torch.nn as nn
+from nets.classifier import Resnet50RoIHead, VGG16RoIHead
+from nets.resnet50 import resnet50
+from nets.rpn import RegionProposalNetwork
+from nets.vgg16 import decom_vgg16
+
+
+class FasterRCNN(nn.Module):
+    def __init__(self,  num_classes,  
+                    mode = "training",
+                    feat_stride = 16,
+                    anchor_scales = [8, 16, 32],
+                    ratios = [0.5, 1, 2],
+                    backbone = 'vgg',
+                    pretrained = False):
+        super(FasterRCNN, self).__init__()
+        self.feat_stride = feat_stride
+        #---------------------------------#
+        #   一共存在两个主干
+        #   vgg和resnet50
+        #---------------------------------#
+        if backbone == 'vgg':
+            self.extractor, classifier = decom_vgg16(pretrained)
+            #---------------------------------#
+            #   构建建议框网络
+            #---------------------------------#
+            self.rpn = RegionProposalNetwork(
+                512, 512,
+                ratios          = ratios,
+                anchor_scales   = anchor_scales,
+                feat_stride     = self.feat_stride,
+                mode            = mode
+            )
+            #---------------------------------#
+            #   构建分类器网络
+            #---------------------------------#
+            self.head = VGG16RoIHead(
+                n_class         = num_classes + 1,
+                roi_size        = 7,
+                spatial_scale   = 1,
+                classifier      = classifier
+            )
+        elif backbone == 'resnet50':
+            self.extractor, classifier = resnet50(pretrained)
+            #---------------------------------#
+            #   构建classifier网络
+            #---------------------------------#
+            self.rpn = RegionProposalNetwork(
+                1024, 512,
+                ratios          = ratios,
+                anchor_scales   = anchor_scales,
+                feat_stride     = self.feat_stride,
+                mode            = mode
+            )
+            #---------------------------------#
+            #   构建classifier网络
+            #---------------------------------#
+            self.head = Resnet50RoIHead(
+                n_class         = num_classes + 1,
+                roi_size        = 14,
+                spatial_scale   = 1,
+                classifier      = classifier
+            )
+            
+    def forward(self, x, scale=1., mode="forward"):
+        if mode == "forward":
+            #---------------------------------#
+            #   计算输入图片的大小
+            #---------------------------------#
+            img_size        = x.shape[2:]
+            #---------------------------------#
+            #   利用主干网络提取特征
+            #---------------------------------#
+            base_feature    = self.extractor.forward(x)
+
+            #---------------------------------#
+            #   获得建议框
+            #---------------------------------#
+            _, _, rois, roi_indices, _  = self.rpn.forward(base_feature, img_size, scale)
+            #---------------------------------------#
+            #   获得classifier的分类结果和回归结果
+            #---------------------------------------#
+            roi_cls_locs, roi_scores    = self.head.forward(base_feature, rois, roi_indices, img_size)
+            return roi_cls_locs, roi_scores, rois, roi_indices
+        elif mode == "extractor":
+            #---------------------------------#
+            #   利用主干网络提取特征
+            #---------------------------------#
+            base_feature    = self.extractor.forward(x)
+            return base_feature
+        elif mode == "rpn":
+            base_feature, img_size = x
+            #---------------------------------#
+            #   获得建议框
+            #---------------------------------#
+            rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn.forward(base_feature, img_size, scale)
+            return rpn_locs, rpn_scores, rois, roi_indices, anchor
+        elif mode == "head":
+            base_feature, rois, roi_indices, img_size = x
+            #---------------------------------------#
+            #   获得classifier的分类结果和回归结果
+            #---------------------------------------#
+            roi_cls_locs, roi_scores    = self.head.forward(base_feature, rois, roi_indices, img_size)
+            return roi_cls_locs, roi_scores
+
+    def freeze_bn(self):
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eval()

+ 393 - 0
nets/frcnn_training.py

@@ -0,0 +1,393 @@
+import math
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+def bbox_iou(bbox_a, bbox_b):
+    if bbox_a.shape[1] != 4 or bbox_b.shape[1] != 4:
+        print(bbox_a, bbox_b)
+        raise IndexError
+    tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
+    br = np.minimum(bbox_a[:, None, 2:], bbox_b[:, 2:])
+    area_i = np.prod(br - tl, axis=2) * (tl < br).all(axis=2)
+    area_a = np.prod(bbox_a[:, 2:] - bbox_a[:, :2], axis=1)
+    area_b = np.prod(bbox_b[:, 2:] - bbox_b[:, :2], axis=1)
+    return area_i / (area_a[:, None] + area_b - area_i)
+
+def bbox2loc(src_bbox, dst_bbox):
+    width = src_bbox[:, 2] - src_bbox[:, 0]
+    height = src_bbox[:, 3] - src_bbox[:, 1]
+    ctr_x = src_bbox[:, 0] + 0.5 * width
+    ctr_y = src_bbox[:, 1] + 0.5 * height
+
+    base_width = dst_bbox[:, 2] - dst_bbox[:, 0]
+    base_height = dst_bbox[:, 3] - dst_bbox[:, 1]
+    base_ctr_x = dst_bbox[:, 0] + 0.5 * base_width
+    base_ctr_y = dst_bbox[:, 1] + 0.5 * base_height
+
+    eps = np.finfo(height.dtype).eps
+    width = np.maximum(width, eps)
+    height = np.maximum(height, eps)
+
+    dx = (base_ctr_x - ctr_x) / width
+    dy = (base_ctr_y - ctr_y) / height
+    dw = np.log(base_width / width)
+    dh = np.log(base_height / height)
+
+    loc = np.vstack((dx, dy, dw, dh)).transpose()
+    return loc
+
+class AnchorTargetCreator(object):
+    def __init__(self, n_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3, pos_ratio=0.5):
+        self.n_sample       = n_sample
+        self.pos_iou_thresh = pos_iou_thresh
+        self.neg_iou_thresh = neg_iou_thresh
+        self.pos_ratio      = pos_ratio
+
+    def __call__(self, bbox, anchor):
+        argmax_ious, label = self._create_label(anchor, bbox)
+        if (label > 0).any():
+            loc = bbox2loc(anchor, bbox[argmax_ious])
+            return loc, label
+        else:
+            return np.zeros_like(anchor), label
+
+    def _calc_ious(self, anchor, bbox):
+        #----------------------------------------------#
+        #   anchor和bbox的iou
+        #   获得的ious的shape为[num_anchors, num_gt]
+        #----------------------------------------------#
+        ious = bbox_iou(anchor, bbox)
+
+        if len(bbox)==0:
+            return np.zeros(len(anchor), np.int32), np.zeros(len(anchor)), np.zeros(len(bbox))
+        #---------------------------------------------------------#
+        #   获得每一个先验框最对应的真实框  [num_anchors, ]
+        #---------------------------------------------------------#
+        argmax_ious = ious.argmax(axis=1)
+        #---------------------------------------------------------#
+        #   找出每一个先验框最对应的真实框的iou  [num_anchors, ]
+        #---------------------------------------------------------#
+        max_ious = np.max(ious, axis=1)
+        #---------------------------------------------------------#
+        #   获得每一个真实框最对应的先验框  [num_gt, ]
+        #---------------------------------------------------------#
+        gt_argmax_ious = ious.argmax(axis=0)
+        #---------------------------------------------------------#
+        #   保证每一个真实框都存在对应的先验框
+        #---------------------------------------------------------#
+        for i in range(len(gt_argmax_ious)):
+            argmax_ious[gt_argmax_ious[i]] = i
+
+        return argmax_ious, max_ious, gt_argmax_ious
+        
+    def _create_label(self, anchor, bbox):
+        # ------------------------------------------ #
+        #   1是正样本,0是负样本,-1忽略
+        #   初始化的时候全部设置为-1
+        # ------------------------------------------ #
+        label = np.empty((len(anchor),), dtype=np.int32)
+        label.fill(-1)
+
+        # ------------------------------------------------------------------------ #
+        #   argmax_ious为每个先验框对应的最大的真实框的序号         [num_anchors, ]
+        #   max_ious为每个真实框对应的最大的真实框的iou             [num_anchors, ]
+        #   gt_argmax_ious为每一个真实框对应的最大的先验框的序号    [num_gt, ]
+        # ------------------------------------------------------------------------ #
+        argmax_ious, max_ious, gt_argmax_ious = self._calc_ious(anchor, bbox)
+        
+        # ----------------------------------------------------- #
+        #   如果小于门限值则设置为负样本
+        #   如果大于门限值则设置为正样本
+        #   每个真实框至少对应一个先验框
+        # ----------------------------------------------------- #
+        label[max_ious < self.neg_iou_thresh] = 0
+        label[max_ious >= self.pos_iou_thresh] = 1
+        if len(gt_argmax_ious)>0:
+            label[gt_argmax_ious] = 1
+
+        # ----------------------------------------------------- #
+        #   判断正样本数量是否大于128,如果大于则限制在128
+        # ----------------------------------------------------- #
+        n_pos = int(self.pos_ratio * self.n_sample)
+        pos_index = np.where(label == 1)[0]
+        if len(pos_index) > n_pos:
+            disable_index = np.random.choice(pos_index, size=(len(pos_index) - n_pos), replace=False)
+            label[disable_index] = -1
+
+        # ----------------------------------------------------- #
+        #   平衡正负样本,保持总数量为256
+        # ----------------------------------------------------- #
+        n_neg = self.n_sample - np.sum(label == 1)
+        neg_index = np.where(label == 0)[0]
+        if len(neg_index) > n_neg:
+            disable_index = np.random.choice(neg_index, size=(len(neg_index) - n_neg), replace=False)
+            label[disable_index] = -1
+
+        return argmax_ious, label
+
+
+class ProposalTargetCreator(object):
+    def __init__(self, n_sample=128, pos_ratio=0.5, pos_iou_thresh=0.5, neg_iou_thresh_high=0.5, neg_iou_thresh_low=0):
+        self.n_sample = n_sample
+        self.pos_ratio = pos_ratio
+        self.pos_roi_per_image = np.round(self.n_sample * self.pos_ratio)
+        self.pos_iou_thresh = pos_iou_thresh
+        self.neg_iou_thresh_high = neg_iou_thresh_high
+        self.neg_iou_thresh_low = neg_iou_thresh_low
+
+    def __call__(self, roi, bbox, label, loc_normalize_std=(0.1, 0.1, 0.2, 0.2)):
+        roi = np.concatenate((roi.detach().cpu().numpy(), bbox), axis=0)
+        # ----------------------------------------------------- #
+        #   计算建议框和真实框的重合程度
+        # ----------------------------------------------------- #
+        iou = bbox_iou(roi, bbox)
+        
+        if len(bbox)==0:
+            gt_assignment = np.zeros(len(roi), np.int32)
+            max_iou = np.zeros(len(roi))
+            gt_roi_label = np.zeros(len(roi))
+        else:
+            #---------------------------------------------------------#
+            #   获得每一个建议框最对应的真实框  [num_roi, ]
+            #---------------------------------------------------------#
+            gt_assignment = iou.argmax(axis=1)
+            #---------------------------------------------------------#
+            #   获得每一个建议框最对应的真实框的iou  [num_roi, ]
+            #---------------------------------------------------------#
+            max_iou = iou.max(axis=1)
+            #---------------------------------------------------------#
+            #   真实框的标签要+1因为有背景的存在
+            #---------------------------------------------------------#
+            gt_roi_label = label[gt_assignment] + 1
+
+        #----------------------------------------------------------------#
+        #   满足建议框和真实框重合程度大于neg_iou_thresh_high的作为负样本
+        #   将正样本的数量限制在self.pos_roi_per_image以内
+        #----------------------------------------------------------------#
+        pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]
+        pos_roi_per_this_image = int(min(self.pos_roi_per_image, pos_index.size))
+        if pos_index.size > 0:
+            pos_index = np.random.choice(pos_index, size=pos_roi_per_this_image, replace=False)
+
+        #-----------------------------------------------------------------------------------------------------#
+        #   满足建议框和真实框重合程度小于neg_iou_thresh_high大于neg_iou_thresh_low作为负样本
+        #   将正样本的数量和负样本的数量的总和固定成self.n_sample
+        #-----------------------------------------------------------------------------------------------------#
+        neg_index = np.where((max_iou < self.neg_iou_thresh_high) & (max_iou >= self.neg_iou_thresh_low))[0]
+        neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image
+        neg_roi_per_this_image = int(min(neg_roi_per_this_image, neg_index.size))
+        if neg_index.size > 0:
+            neg_index = np.random.choice(neg_index, size=neg_roi_per_this_image, replace=False)
+            
+        #---------------------------------------------------------#
+        #   sample_roi      [n_sample, ]
+        #   gt_roi_loc      [n_sample, 4]
+        #   gt_roi_label    [n_sample, ]
+        #---------------------------------------------------------#
+        keep_index = np.append(pos_index, neg_index)
+
+        sample_roi = roi[keep_index]
+        if len(bbox)==0:
+            return sample_roi, np.zeros_like(sample_roi), gt_roi_label[keep_index]
+
+        gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]])
+        gt_roi_loc = (gt_roi_loc / np.array(loc_normalize_std, np.float32))
+
+        gt_roi_label = gt_roi_label[keep_index]
+        gt_roi_label[pos_roi_per_this_image:] = 0
+        return sample_roi, gt_roi_loc, gt_roi_label
+
+class FasterRCNNTrainer(nn.Module):
+    def __init__(self, model_train, optimizer):
+        super(FasterRCNNTrainer, self).__init__()
+        self.model_train    = model_train
+        self.optimizer      = optimizer
+
+        self.rpn_sigma      = 1
+        self.roi_sigma      = 1
+
+        self.anchor_target_creator      = AnchorTargetCreator()
+        self.proposal_target_creator    = ProposalTargetCreator()
+
+        self.loc_normalize_std          = [0.1, 0.1, 0.2, 0.2]
+
+    def _fast_rcnn_loc_loss(self, pred_loc, gt_loc, gt_label, sigma):
+        pred_loc    = pred_loc[gt_label > 0]
+        gt_loc      = gt_loc[gt_label > 0]
+
+        sigma_squared = sigma ** 2
+        regression_diff = (gt_loc - pred_loc)
+        regression_diff = regression_diff.abs().float()
+        regression_loss = torch.where(
+                regression_diff < (1. / sigma_squared),
+                0.5 * sigma_squared * regression_diff ** 2,
+                regression_diff - 0.5 / sigma_squared
+            )
+        regression_loss = regression_loss.sum()
+        num_pos         = (gt_label > 0).sum().float()
+        
+        regression_loss /= torch.max(num_pos, torch.ones_like(num_pos))
+        return regression_loss
+        
+    def forward(self, imgs, bboxes, labels, scale):
+        n           = imgs.shape[0]
+        img_size    = imgs.shape[2:]
+        #-------------------------------#
+        #   获取公用特征层
+        #-------------------------------#
+        base_feature = self.model_train(imgs, mode = 'extractor')
+
+        # -------------------------------------------------- #
+        #   利用rpn网络获得调整参数、得分、建议框、先验框
+        # -------------------------------------------------- #
+        rpn_locs, rpn_scores, rois, roi_indices, anchor = self.model_train(x = [base_feature, img_size], scale = scale, mode = 'rpn')
+        
+        rpn_loc_loss_all, rpn_cls_loss_all, roi_loc_loss_all, roi_cls_loss_all  = 0, 0, 0, 0
+        sample_rois, sample_indexes, gt_roi_locs, gt_roi_labels                 = [], [], [], []
+        for i in range(n):
+            bbox        = bboxes[i]
+            label       = labels[i]
+            rpn_loc     = rpn_locs[i]
+            rpn_score   = rpn_scores[i]
+            roi         = rois[i]
+            # -------------------------------------------------- #
+            #   利用真实框和先验框获得建议框网络应该有的预测结果
+            #   给每个先验框都打上标签
+            #   gt_rpn_loc      [num_anchors, 4]
+            #   gt_rpn_label    [num_anchors, ]
+            # -------------------------------------------------- #
+            gt_rpn_loc, gt_rpn_label    = self.anchor_target_creator(bbox, anchor[0].cpu().numpy())
+            gt_rpn_loc                  = torch.Tensor(gt_rpn_loc).type_as(rpn_locs)
+            gt_rpn_label                = torch.Tensor(gt_rpn_label).type_as(rpn_locs).long()
+            # -------------------------------------------------- #
+            #   分别计算建议框网络的回归损失和分类损失
+            # -------------------------------------------------- #
+            rpn_loc_loss = self._fast_rcnn_loc_loss(rpn_loc, gt_rpn_loc, gt_rpn_label, self.rpn_sigma)
+            rpn_cls_loss = F.cross_entropy(rpn_score, gt_rpn_label, ignore_index=-1)
+  
+            rpn_loc_loss_all += rpn_loc_loss
+            rpn_cls_loss_all += rpn_cls_loss
+            # ------------------------------------------------------ #
+            #   利用真实框和建议框获得classifier网络应该有的预测结果
+            #   获得三个变量,分别是sample_roi, gt_roi_loc, gt_roi_label
+            #   sample_roi      [n_sample, ]
+            #   gt_roi_loc      [n_sample, 4]
+            #   gt_roi_label    [n_sample, ]
+            # ------------------------------------------------------ #
+            sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(roi, bbox, label, self.loc_normalize_std)
+            sample_rois.append(torch.Tensor(sample_roi).type_as(rpn_locs))
+            sample_indexes.append(torch.ones(len(sample_roi)).type_as(rpn_locs) * roi_indices[i][0])
+            gt_roi_locs.append(torch.Tensor(gt_roi_loc).type_as(rpn_locs))
+            gt_roi_labels.append(torch.Tensor(gt_roi_label).type_as(rpn_locs).long())
+            
+        sample_rois     = torch.stack(sample_rois, dim=0)
+        sample_indexes  = torch.stack(sample_indexes, dim=0)
+        roi_cls_locs, roi_scores = self.model_train([base_feature, sample_rois, sample_indexes, img_size], mode = 'head')
+        for i in range(n):
+            # ------------------------------------------------------ #
+            #   根据建议框的种类,取出对应的回归预测结果
+            # ------------------------------------------------------ #
+            n_sample = roi_cls_locs.size()[1]
+            
+            roi_cls_loc     = roi_cls_locs[i]
+            roi_score       = roi_scores[i]
+            gt_roi_loc      = gt_roi_locs[i]
+            gt_roi_label    = gt_roi_labels[i]
+            
+            roi_cls_loc = roi_cls_loc.view(n_sample, -1, 4)
+            roi_loc     = roi_cls_loc[torch.arange(0, n_sample), gt_roi_label]
+
+            # -------------------------------------------------- #
+            #   分别计算Classifier网络的回归损失和分类损失
+            # -------------------------------------------------- #
+            roi_loc_loss = self._fast_rcnn_loc_loss(roi_loc, gt_roi_loc, gt_roi_label.data, self.roi_sigma)
+            roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label)
+
+            roi_loc_loss_all += roi_loc_loss
+            roi_cls_loss_all += roi_cls_loss
+            
+        losses = [rpn_loc_loss_all/n, rpn_cls_loss_all/n, roi_loc_loss_all/n, roi_cls_loss_all/n]
+        losses = losses + [sum(losses)]
+        return losses
+
+    def train_step(self, imgs, bboxes, labels, scale, fp16=False, scaler=None):
+        self.optimizer.zero_grad()
+        if not fp16:
+            losses = self.forward(imgs, bboxes, labels, scale)
+            losses[-1].backward()
+            self.optimizer.step()
+        else:
+            from torch.cuda.amp import autocast
+            with autocast():
+                losses = self.forward(imgs, bboxes, labels, scale)
+
+            #----------------------#
+            #   反向传播
+            #----------------------#
+            scaler.scale(losses[-1]).backward()
+            scaler.step(self.optimizer)
+            scaler.update()
+            
+        return losses
+
+def weights_init(net, init_type='normal', init_gain=0.02):
+    def init_func(m):
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and classname.find('Conv') != -1:
+            if init_type == 'normal':
+                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+        elif classname.find('BatchNorm2d') != -1:
+            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+            torch.nn.init.constant_(m.bias.data, 0.0)
+    print('initialize network with %s type' % init_type)
+    net.apply(init_func)
+
+def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
+    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
+        if iters <= warmup_total_iters:
+            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
+        elif iters >= total_iters - no_aug_iter:
+            lr = min_lr
+        else:
+            lr = min_lr + 0.5 * (lr - min_lr) * (
+                1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
+            )
+        return lr
+
+    def step_lr(lr, decay_rate, step_size, iters):
+        if step_size < 1:
+            raise ValueError("step_size must above 1.")
+        n       = iters // step_size
+        out_lr  = lr * decay_rate ** n
+        return out_lr
+
+    if lr_decay_type == "cos":
+        warmup_total_iters  = min(max(warmup_iters_ratio * total_iters, 1), 3)
+        warmup_lr_start     = max(warmup_lr_ratio * lr, 1e-6)
+        no_aug_iter         = min(max(no_aug_iter_ratio * total_iters, 1), 15)
+        func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
+    else:
+        decay_rate  = (min_lr / lr) ** (1 / (step_num - 1))
+        step_size   = total_iters / step_num
+        func = partial(step_lr, lr, decay_rate, step_size)
+
+    return func
+
+def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
+    lr = lr_scheduler_func(epoch)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr

+ 130 - 0
nets/resnet50.py

@@ -0,0 +1,130 @@
+import math
+
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, num_classes=1000):
+        #-----------------------------------#
+        #   假设输入进来的图片是600,600,3
+        #-----------------------------------#
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+
+        # 600,600,3 -> 300,300,64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+
+        # 300,300,64 -> 150,150,64
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
+
+        # 150,150,64 -> 150,150,256
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        # 150,150,256 -> 75,75,512
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        # 75,75,512 -> 38,38,1024 到这里可以获得一个38,38,1024的共享特征层
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        # self.layer4被用在classifier模型中
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        
+        self.avgpool = nn.AvgPool2d(7)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        #-------------------------------------------------------------------#
+        #   当模型需要进行高和宽的压缩的时候,就需要用到残差边的downsample
+        #-------------------------------------------------------------------#
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+def resnet50(pretrained = False):
+    model = ResNet(Bottleneck, [3, 4, 6, 3])
+    if pretrained:
+        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/resnet50-19c8e357.pth", model_dir="./model_data")
+        model.load_state_dict(state_dict)
+    #----------------------------------------------------------------------------#
+    #   获取特征提取部分,从conv1到model.layer3,最终获得一个38,38,1024的特征层
+    #----------------------------------------------------------------------------#
+    features    = list([model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3])
+    #----------------------------------------------------------------------------#
+    #   获取分类部分,从model.layer4到model.avgpool
+    #----------------------------------------------------------------------------#
+    classifier  = list([model.layer4, model.avgpool])
+    
+    features    = nn.Sequential(*features)
+    classifier  = nn.Sequential(*classifier)
+    return features, classifier

+ 191 - 0
nets/rpn.py

@@ -0,0 +1,191 @@
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.ops import nms
+from utils.anchors import _enumerate_shifted_anchor, generate_anchor_base
+from utils.utils_bbox import loc2bbox
+
+
+class ProposalCreator():
+    def __init__(
+        self, 
+        mode, 
+        nms_iou             = 0.7,
+        n_train_pre_nms     = 12000,
+        n_train_post_nms    = 600,
+        n_test_pre_nms      = 3000,
+        n_test_post_nms     = 300,
+        min_size            = 16
+    
+    ):
+        #-----------------------------------#
+        #   设置预测还是训练
+        #-----------------------------------#
+        self.mode               = mode
+        #-----------------------------------#
+        #   建议框非极大抑制的iou大小
+        #-----------------------------------#
+        self.nms_iou            = nms_iou
+        #-----------------------------------#
+        #   训练用到的建议框数量
+        #-----------------------------------#
+        self.n_train_pre_nms    = n_train_pre_nms
+        self.n_train_post_nms   = n_train_post_nms
+        #-----------------------------------#
+        #   预测用到的建议框数量
+        #-----------------------------------#
+        self.n_test_pre_nms     = n_test_pre_nms
+        self.n_test_post_nms    = n_test_post_nms
+        self.min_size           = min_size
+
+    def __call__(self, loc, score, anchor, img_size, scale=1.):
+        if self.mode == "training":
+            n_pre_nms   = self.n_train_pre_nms
+            n_post_nms  = self.n_train_post_nms
+        else:
+            n_pre_nms   = self.n_test_pre_nms
+            n_post_nms  = self.n_test_post_nms
+
+        #-----------------------------------#
+        #   将先验框转换成tensor
+        #-----------------------------------#
+        anchor = torch.from_numpy(anchor).type_as(loc)
+        #-----------------------------------#
+        #   将RPN网络预测结果转化成建议框
+        #-----------------------------------#
+        roi = loc2bbox(anchor, loc)
+        #-----------------------------------#
+        #   防止建议框超出图像边缘
+        #-----------------------------------#
+        roi[:, [0, 2]] = torch.clamp(roi[:, [0, 2]], min = 0, max = img_size[1])
+        roi[:, [1, 3]] = torch.clamp(roi[:, [1, 3]], min = 0, max = img_size[0])
+        
+        #-----------------------------------#
+        #   建议框的宽高的最小值不可以小于16
+        #-----------------------------------#
+        min_size    = self.min_size * scale
+        keep        = torch.where(((roi[:, 2] - roi[:, 0]) >= min_size) & ((roi[:, 3] - roi[:, 1]) >= min_size))[0]
+        #-----------------------------------#
+        #   将对应的建议框保留下来
+        #-----------------------------------#
+        roi         = roi[keep, :]
+        score       = score[keep]
+
+        #-----------------------------------#
+        #   根据得分进行排序,取出建议框
+        #-----------------------------------#
+        order       = torch.argsort(score, descending=True)
+        if n_pre_nms > 0:
+            order   = order[:n_pre_nms]
+        roi     = roi[order, :]
+        score   = score[order]
+
+        #-----------------------------------#
+        #   对建议框进行非极大抑制
+        #   使用官方的非极大抑制会快非常多
+        #-----------------------------------#
+        keep    = nms(roi, score, self.nms_iou)
+        if len(keep) < n_post_nms:
+            index_extra = np.random.choice(range(len(keep)), size=(n_post_nms - len(keep)), replace=True)
+            keep        = torch.cat([keep, keep[index_extra]])
+        keep    = keep[:n_post_nms]
+        roi     = roi[keep]
+        return roi
+
+
+class RegionProposalNetwork(nn.Module):
+    def __init__(
+        self, 
+        in_channels     = 512, 
+        mid_channels    = 512, 
+        ratios          = [0.5, 1, 2],
+        anchor_scales   = [8, 16, 32], 
+        feat_stride     = 16,
+        mode            = "training",
+    ):
+        super(RegionProposalNetwork, self).__init__()
+        #-----------------------------------------#
+        #   生成基础先验框,shape为[9, 4]
+        #-----------------------------------------#
+        self.anchor_base    = generate_anchor_base(anchor_scales = anchor_scales, ratios = ratios)
+        n_anchor            = self.anchor_base.shape[0]
+
+        #-----------------------------------------#
+        #   先进行一个3x3的卷积,可理解为特征整合
+        #-----------------------------------------#
+        self.conv1  = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
+        #-----------------------------------------#
+        #   分类预测先验框内部是否包含物体
+        #-----------------------------------------#
+        self.score  = nn.Conv2d(mid_channels, n_anchor * 2, 1, 1, 0)
+        #-----------------------------------------#
+        #   回归预测对先验框进行调整
+        #-----------------------------------------#
+        self.loc    = nn.Conv2d(mid_channels, n_anchor * 4, 1, 1, 0)
+
+        #-----------------------------------------#
+        #   特征点间距步长
+        #-----------------------------------------#
+        self.feat_stride    = feat_stride
+        #-----------------------------------------#
+        #   用于对建议框解码并进行非极大抑制
+        #-----------------------------------------#
+        self.proposal_layer = ProposalCreator(mode)
+        #--------------------------------------#
+        #   对FPN的网络部分进行权值初始化
+        #--------------------------------------#
+        normal_init(self.conv1, 0, 0.01)
+        normal_init(self.score, 0, 0.01)
+        normal_init(self.loc, 0, 0.01)
+
+    def forward(self, x, img_size, scale=1.):
+        n, _, h, w = x.shape
+        #-----------------------------------------#
+        #   先进行一个3x3的卷积,可理解为特征整合
+        #-----------------------------------------#
+        x = F.relu(self.conv1(x))
+        #-----------------------------------------#
+        #   回归预测对先验框进行调整
+        #-----------------------------------------#
+        rpn_locs = self.loc(x)
+        rpn_locs = rpn_locs.permute(0, 2, 3, 1).contiguous().view(n, -1, 4)
+        #-----------------------------------------#
+        #   分类预测先验框内部是否包含物体
+        #-----------------------------------------#
+        rpn_scores = self.score(x)
+        rpn_scores = rpn_scores.permute(0, 2, 3, 1).contiguous().view(n, -1, 2)
+        
+        #--------------------------------------------------------------------------------------#
+        #   进行softmax概率计算,每个先验框只有两个判别结果
+        #   内部包含物体或者内部不包含物体,rpn_softmax_scores[:, :, 1]的内容为包含物体的概率
+        #--------------------------------------------------------------------------------------#
+        rpn_softmax_scores  = F.softmax(rpn_scores, dim=-1)
+        rpn_fg_scores       = rpn_softmax_scores[:, :, 1].contiguous()
+        rpn_fg_scores       = rpn_fg_scores.view(n, -1)
+
+        #------------------------------------------------------------------------------------------------#
+        #   生成先验框,此时获得的anchor是布满网格点的,当输入图片为600,600,3的时候,shape为(12996, 4)
+        #------------------------------------------------------------------------------------------------#
+        anchor = _enumerate_shifted_anchor(np.array(self.anchor_base), self.feat_stride, h, w)
+        rois        = list()
+        roi_indices = list()
+        for i in range(n):
+            roi         = self.proposal_layer(rpn_locs[i], rpn_fg_scores[i], anchor, img_size, scale = scale)
+            batch_index = i * torch.ones((len(roi),))
+            rois.append(roi.unsqueeze(0))
+            roi_indices.append(batch_index.unsqueeze(0))
+
+        rois        = torch.cat(rois, dim=0).type_as(x)
+        roi_indices = torch.cat(roi_indices, dim=0).type_as(x)
+        anchor      = torch.from_numpy(anchor).unsqueeze(0).float().to(x.device)
+        
+        return rpn_locs, rpn_scores, rois, roi_indices, anchor
+
+def normal_init(m, mean, stddev, truncated=False):
+    if truncated:
+        m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)  # not a perfect approximation
+    else:
+        m.weight.data.normal_(mean, stddev)
+        m.bias.data.zero_()

+ 111 - 0
nets/vgg16.py

@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+
+#--------------------------------------#
+#   VGG16的结构
+#--------------------------------------#
+class VGG(nn.Module):
+    def __init__(self, features, num_classes=1000, init_weights=True):
+        super(VGG, self).__init__()
+        self.features = features
+        #--------------------------------------#
+        #   平均池化到7x7大小
+        #--------------------------------------#
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+        #--------------------------------------#
+        #   分类部分
+        #--------------------------------------#
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 7 * 7, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, num_classes),
+        )
+        if init_weights:
+            self._initialize_weights()
+
+    def forward(self, x):
+        #--------------------------------------#
+        #   特征提取
+        #--------------------------------------#
+        x = self.features(x)
+        #--------------------------------------#
+        #   平均池化
+        #--------------------------------------#
+        x = self.avgpool(x)
+        #--------------------------------------#
+        #   平铺后
+        #--------------------------------------#
+        x = torch.flatten(x, 1)
+        #--------------------------------------#
+        #   分类部分
+        #--------------------------------------#
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+'''
+假设输入图像为(600, 600, 3),随着cfg的循环,特征层变化如下:
+600,600,3 -> 600,600,64 -> 600,600,64 -> 300,300,64 -> 300,300,128 -> 300,300,128 -> 150,150,128 -> 150,150,256 -> 150,150,256 -> 150,150,256 
+-> 75,75,256 -> 75,75,512 -> 75,75,512 -> 75,75,512 -> 37,37,512 ->  37,37,512 -> 37,37,512 -> 37,37,512
+到cfg结束,我们获得了一个37,37,512的特征层
+'''
+
+cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
+
+
+#--------------------------------------#
+#   特征提取部分
+#--------------------------------------#
+def make_layers(cfg, batch_norm = False):
+    layers = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)]
+        else:
+            conv2d = nn.Conv2d(in_channels, v, kernel_size = 3, padding = 1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)]
+            else:
+                layers += [conv2d, nn.ReLU(inplace = True)]
+            in_channels = v 
+    return nn.Sequential(*layers)
+
+def decom_vgg16(pretrained = False):
+    model = VGG(make_layers(cfg))
+    if pretrained:
+        state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir = "./model_data")
+        model.load_state_dict(state_dict)
+    #----------------------------------------------------------------------------#
+    #   获取特征提取部分,最终获得一个37,37,1024的特征层
+    #----------------------------------------------------------------------------#
+    features    = list(model.features)[:30]
+    #----------------------------------------------------------------------------#
+    #   获取分类部分,需要除去Dropout部分
+    #----------------------------------------------------------------------------#
+    classifier  = list(model.classifier)
+    del classifier[6]
+    del classifier[5]
+    del classifier[2]
+
+    features    = nn.Sequential(*features)
+    classifier  = nn.Sequential(*classifier)
+    return features, classifier

+ 137 - 0
predict.py

@@ -0,0 +1,137 @@
+#----------------------------------------------------#
+#   将单张图片预测、摄像头检测和FPS测试功能
+#   整合到了一个py文件中,通过指定mode进行模式的修改。
+#----------------------------------------------------#
+import time
+import cv2
+import numpy as np
+from PIL import Image
+import os
+from tqdm import tqdm
+from frcnn import FRCNN
+
+
+if __name__ == "__main__":
+    frcnn = FRCNN()
+    #----------------------------------------------------------------------------------------------------------#
+    #   mode用于指定测试的模式:
+    #   'predict'           表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
+    #   'video'             表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
+    #   'fps'               表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
+    #   'dir_predict'       表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
+    #----------------------------------------------------------------------------------------------------------#
+    mode = "dir_predict"
+    #-------------------------------------------------------------------------#
+    #   crop                指定了是否在单张图片预测后对目标进行截取
+    #   count               指定了是否进行目标的计数
+    #   crop、count仅在mode='predict'时有效
+    #-------------------------------------------------------------------------#
+    crop            = False
+    count           = False
+    #----------------------------------------------------------------------------------------------------------#
+    #   video_path          用于指定视频的路径,当video_path=0时表示检测摄像头
+    #                       想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
+    #   video_save_path     表示视频保存的路径,当video_save_path=""时表示不保存
+    #                       想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
+    #   video_fps           用于保存的视频的fps
+    #
+    #   video_path、video_save_path和video_fps仅在mode='video'时有效
+    #   保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
+    #----------------------------------------------------------------------------------------------------------#
+    video_path      = 0
+    video_save_path = ""
+    video_fps       = 25.0
+    #----------------------------------------------------------------------------------------------------------#
+    #   test_interval       用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
+    #   fps_image_path      用于指定测试的fps图片
+    #   
+    #   test_interval和fps_image_path仅在mode='fps'有效
+    #----------------------------------------------------------------------------------------------------------#
+    test_interval   = 100
+    fps_image_path  = "img/street.jpg"
+    #-------------------------------------------------------------------------#
+    #   dir_origin_path     指定了用于检测的图片的文件夹路径
+    #   dir_save_path       指定了检测完图片的保存路径
+    #   
+    #   dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
+    #-------------------------------------------------------------------------#
+    dir_origin_path = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_wm_val/JPEGImages/"
+    dir_save_path   = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_wm_val/JPEGImages_out/"
+
+
+    if mode == "predict":
+        '''
+        1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
+        具体流程可以参考get_dr_txt.py,在get_dr_txt.py即实现了遍历还实现了目标信息的保存。
+        2、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 
+        3、如果想要获得预测框的坐标,可以进入frcnn.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
+        4、如果想要利用预测框截取下目标,可以进入frcnn.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
+        在原图上利用矩阵的方式进行截取。
+        5、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入frcnn.detect_image函数,在绘图部分对predicted_class进行判断,
+        比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
+        '''
+        while True:
+            img = input('Input image filename:')
+            try:
+                image = Image.open(img)
+            except:
+                print('Open Error! Try again!')
+            else:
+                r_image = frcnn.detect_image(image, crop = crop, count = count)
+                r_image.show()
+                
+
+    elif mode == "video":
+        capture = cv2.VideoCapture(video_path)
+        if video_save_path != "":
+            fourcc = cv2.VideoWriter_fourcc(*'XVID')
+            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
+        fps = 0.0
+        while(True):
+            t1 = time.time()
+            # 读取某一帧
+            ref,frame = capture.read()
+            # 格式转变,BGRtoRGB
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            # 转变成Image
+            frame = Image.fromarray(np.uint8(frame))
+            # 进行检测
+            frame = np.array(frcnn.detect_image(frame))
+            # RGBtoBGR满足opencv显示格式
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+            fps  = ( fps + (1. / (time.time() - t1)) ) / 2
+            print("fps = %.2f"%(fps))
+            frame = cv2.putText(frame, "fps = %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            cv2.imshow("video", frame)
+            c = cv2.waitKey(1) & 0xff 
+            if video_save_path != "":
+                out.write(frame)
+            if c == 27:
+                capture.release()
+                break
+        capture.release()
+        out.release()
+        cv2.destroyAllWindows()
+
+
+    elif mode == "fps":
+        img = Image.open(fps_image_path)
+        tact_time = frcnn.get_FPS(img, test_interval)
+        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
+
+
+    elif mode == "dir_predict":
+        img_names = os.listdir(dir_origin_path)
+        for img_name in tqdm(img_names):
+            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
+                image_path  = os.path.join(dir_origin_path, img_name)
+                image       = Image.open(image_path)
+                r_image     = frcnn.detect_image(image)
+                if not os.path.exists(dir_save_path):
+                    os.makedirs(dir_save_path)
+                r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality = 95, subsampling = 0)
+
+
+    else:
+        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")

+ 11 - 0
requirements.txt

@@ -0,0 +1,11 @@
+python == 3.10.6
+numpy == 1.23.3
+opencv == 4.6.0
+pillow == 9.2.0
+pycocotools == 2.0.6
+pytorch == 1.12.1
+scipy == 1.9.3
+torchvision == 0.13.1
+tqdm == 4.64.1
+matplotlib == 3.6.2
+hdf5 == 1.12.1

+ 29 - 0
summary.py

@@ -0,0 +1,29 @@
+#--------------------------------------------#
+#   该部分代码用于看网络结构
+#--------------------------------------------#
+import torch
+from thop import clever_format, profile
+from torchsummary import summary
+
+from nets.frcnn import FasterRCNN
+
+if __name__ == "__main__":
+    input_shape     = [600, 600]
+    num_classes     = 21
+    
+    device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model   = FasterRCNN(num_classes, backbone = 'vgg').to(device)
+    summary(model, (3, input_shape[0], input_shape[1]))
+    
+    dummy_input     = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
+    flops, params   = profile(model.to(device), (dummy_input, ), verbose = False)
+    #--------------------------------------------------------#
+    #   flops * 2是因为profile没有将卷积作为两个operations
+    #   有些论文将卷积算乘法、加法两个operations。此时乘2
+    #   有些论文只考虑乘法的运算次数,忽略加法。此时不乘2
+    #   本代码选择乘2,参考YOLOX。
+    #--------------------------------------------------------#
+    flops           = flops * 2
+    flops, params   = clever_format([flops, params], "%.3f")
+    print('Total GFLOPS: %s' % (flops))
+    print('Total params: %s' % (params))

+ 444 - 0
train.py

@@ -0,0 +1,444 @@
+#-------------------------------------#
+#       对数据集进行训练
+#-------------------------------------#
+import os
+import datetime
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from nets.frcnn import FasterRCNN
+from nets.frcnn_training import (FasterRCNNTrainer, get_lr_scheduler,
+                                 set_optimizer_lr, weights_init)
+from utils.callbacks import EvalCallback, LossHistory
+from utils.dataloader import FRCNNDataset, frcnn_dataset_collate
+from utils.utils import get_classes, show_config
+from utils.utils_fit import fit_one_epoch
+
+
+'''
+训练自己的目标检测模型一定需要注意以下几点:
+1、训练前仔细检查自己的格式是否满足要求,该库要求数据集格式为VOC格式,需要准备好的内容有输入图片和标签
+   输入图片为.jpg图片,无需固定大小,传入训练前会自动进行resize。
+   灰度图会自动转成RGB图片进行训练,无需自己修改。
+   输入图片如果后缀非jpg,需要自己批量转成jpg后再开始训练。
+
+   标签为.xml格式,文件中会有需要检测的目标信息,标签文件和输入图片文件相对应。
+
+2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。
+   损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。
+   训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中
+   
+3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。
+   如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。
+'''
+
+
+if __name__ == "__main__":
+    #-------------------------------#
+    #   是否使用Cuda
+    #   没有GPU可以设置成False
+    #-------------------------------#
+    Cuda            = True
+    #---------------------------------------------------------------------#
+    #   train_gpu   训练用到的GPU
+    #               默认为第一张卡、双卡为[0, 1]、三卡为[0, 1, 2]
+    #               在使用多GPU时,每个卡上的batch为总batch除以卡的数量。
+    #---------------------------------------------------------------------#
+    train_gpu       = [0]
+    #---------------------------------------------------------------------#
+    #   fp16        是否使用混合精度训练
+    #               可减少约一半的显存、需要pytorch1.7.1以上
+    #---------------------------------------------------------------------#
+    fp16            = False
+    #---------------------------------------------------------------------#
+    #   classes_path    指向model_data下的txt,与自己训练的数据集相关 
+    #                   训练前一定要修改classes_path,使其对应自己的数据集
+    #---------------------------------------------------------------------#
+    classes_path    = '/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/voc_classes.txt'
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #   权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。
+    #   模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。
+    #   预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好
+    #
+    #   如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。
+    #   同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。
+    #   
+    #   当model_path = ''的时候不加载整个模型的权值。
+    #
+    #   此处使用的是整个模型的权重,因此是在train.py进行加载的,下面的pretrain不影响此处的权值加载。
+    #   如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',下面的pretrain = True,此时仅加载主干。
+    #   如果想要让模型从0开始训练,则设置model_path = '',下面的pretrain = Fasle,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。
+    #   
+    #   一般来讲,网络从0开始的训练效果会很差,因为权值太过随机,特征提取效果不明显,因此非常、非常、非常不建议大家从0开始训练!
+    #   如果一定要从0开始,可以了解imagenet数据集,首先训练分类模型,获得网络的主干部分权值,分类模型的 主干部分 和该模型通用,基于此进行训练。
+    #----------------------------------------------------------------------------------------------------------------------------#
+    model_path      = '/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/voc_weights_resnet.pth'
+    #------------------------------------------------------#
+    #   input_shape     输入的shape大小
+    #------------------------------------------------------#
+    input_shape     = [600, 600]
+    #---------------------------------------------#
+    #   vgg
+    #   resnet50
+    #---------------------------------------------#
+    backbone        = "resnet50"
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #   pretrained      是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。
+    #                   如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。
+    #                   如果不设置model_path,pretrained = True,此时仅加载主干开始训练。
+    #                   如果不设置model_path,pretrained = False,Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。
+    #----------------------------------------------------------------------------------------------------------------------------#
+    pretrained      = False
+    #------------------------------------------------------------------------#
+    #   anchors_size用于设定先验框的大小,每个特征点均存在9个先验框。
+    #   anchors_size每个数对应3个先验框。
+    #   当anchors_size = [8, 16, 32]的时候,生成的先验框宽高约为:
+    #   [90, 180] ; [180, 360]; [360, 720]; [128, 128]; 
+    #   [256, 256]; [512, 512]; [180, 90] ; [360, 180]; 
+    #   [720, 360]; 详情查看anchors.py
+    #   如果想要检测小物体,可以减小anchors_size靠前的数。
+    #   比如设置anchors_size = [4, 16, 32]
+    #------------------------------------------------------------------------#
+    anchors_size    = [8, 16, 32]
+
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #   训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。
+    #   冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,此时仅仅进行冻结训练。
+    #      
+    #   在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整:
+    #   (一)从整个模型的预训练权重开始训练: 
+    #       Adam:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4。(不冻结)
+    #       SGD:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 150,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 150,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结)
+    #       其中:UnFreeze_Epoch可以在100-300之间调整。
+    #   (二)从主干网络的预训练权重开始训练:
+    #       Adam:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-4。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-4。(不冻结)
+    #       SGD:
+    #           Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 150,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2。(冻结)
+    #           Init_Epoch = 0,UnFreeze_Epoch = 150,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2。(不冻结)
+    #       其中:由于从主干网络的预训练权重开始训练,主干的权值不一定适合目标检测,需要更多的训练跳出局部最优解。
+    #             UnFreeze_Epoch可以在150-300之间调整,YOLOV5和YOLOX均推荐使用300。
+    #             Adam相较于SGD收敛的快一些。因此UnFreeze_Epoch理论上可以小一点,但依然推荐更多的Epoch。
+    #   (三)batch_size的设置:
+    #       在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。
+    #       faster rcnn的Batch BatchNormalization层已经冻结,batch_size可以为1
+    #----------------------------------------------------------------------------------------------------------------------------#
+    #------------------------------------------------------------------#
+    #   冻结阶段训练参数
+    #   此时模型的主干被冻结了,特征提取网络不发生改变
+    #   占用的显存较小,仅对网络进行微调
+    #   Init_Epoch          模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置:
+    #                       Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100
+    #                       会跳过冻结阶段,直接从60代开始,并调整对应的学习率。
+    #                       (断点续练时使用)
+    #   Freeze_Epoch        模型冻结训练的Freeze_Epoch
+    #                       (当Freeze_Train=False时失效)
+    #   Freeze_batch_size   模型冻结训练的batch_size
+    #                       (当Freeze_Train=False时失效)
+    #------------------------------------------------------------------#
+    Init_Epoch          = 0
+    Freeze_Epoch        = 50
+    Freeze_batch_size   = 4
+    #------------------------------------------------------------------#
+    #   解冻阶段训练参数
+    #   此时模型的主干不被冻结了,特征提取网络会发生改变
+    #   占用的显存较大,网络所有的参数都会发生改变
+    #   UnFreeze_Epoch          模型总共训练的epoch
+    #                           SGD需要更长的时间收敛,因此设置较大的UnFreeze_Epoch
+    #                           Adam可以使用相对较小的UnFreeze_Epoch
+    #   Unfreeze_batch_size     模型在解冻后的batch_size
+    #------------------------------------------------------------------#
+    UnFreeze_Epoch      = 100
+    Unfreeze_batch_size = 2
+    #------------------------------------------------------------------#
+    #   Freeze_Train    是否进行冻结训练
+    #                   默认先冻结主干训练后解冻训练。
+    #                   如果设置Freeze_Train=False,建议使用优化器为sgd
+    #------------------------------------------------------------------#
+    Freeze_Train        = True
+    
+    #------------------------------------------------------------------#
+    #   其它训练参数:学习率、优化器、学习率下降有关
+    #------------------------------------------------------------------#
+    #------------------------------------------------------------------#
+    #   Init_lr         模型的最大学习率
+    #                   当使用Adam优化器时建议设置  Init_lr=1e-4
+    #                   当使用SGD优化器时建议设置   Init_lr=1e-2
+    #   Min_lr          模型的最小学习率,默认为最大学习率的0.01
+    #------------------------------------------------------------------#
+    Init_lr             = 1e-4
+    Min_lr              = Init_lr * 0.01
+    #------------------------------------------------------------------#
+    #   optimizer_type  使用到的优化器种类,可选的有adam、sgd
+    #                   当使用Adam优化器时建议设置  Init_lr=1e-4
+    #                   当使用SGD优化器时建议设置   Init_lr=1e-2
+    #   momentum        优化器内部使用到的momentum参数
+    #   weight_decay    权值衰减,可防止过拟合
+    #                   adam会导致weight_decay错误,使用adam时建议设置为0。
+    #------------------------------------------------------------------#
+    optimizer_type      = "adam"
+    momentum            = 0.9
+    weight_decay        = 0
+    #------------------------------------------------------------------#
+    #   lr_decay_type   使用到的学习率下降方式,可选的有'step'、'cos'
+    #------------------------------------------------------------------#
+    lr_decay_type       = 'cos'
+    #------------------------------------------------------------------#
+    #   save_period     多少个epoch保存一次权值
+    #------------------------------------------------------------------#
+    save_period         = 5
+    #------------------------------------------------------------------#
+    #   save_dir        权值与日志文件保存的文件夹
+    #------------------------------------------------------------------#
+    save_dir            = 'logs_wm'
+    #------------------------------------------------------------------#
+    #   eval_flag       是否在训练时进行评估,评估对象为验证集
+    #                   安装pycocotools库后,评估体验更佳。
+    #   eval_period     代表多少个epoch评估一次,不建议频繁的评估
+    #                   评估需要消耗较多的时间,频繁评估会导致训练非常慢
+    #   此处获得的mAP会与get_map.py获得的会有所不同,原因有二:
+    #   (一)此处获得的mAP为验证集的mAP。
+    #   (二)此处设置评估参数较为保守,目的是加快评估速度。
+    #------------------------------------------------------------------#
+    eval_flag           = True
+    eval_period         = 5
+    #------------------------------------------------------------------#
+    #   num_workers     用于设置是否使用多线程读取数据,1代表关闭多线程
+    #                   开启后会加快数据读取速度,但是会占用更多内存
+    #                   在IO为瓶颈的时候再开启多线程,即GPU运算速度远大于读取图片的速度。
+    #------------------------------------------------------------------#
+    num_workers         = 4
+    #----------------------------------------------------#
+    #   获得图片路径和标签
+    #----------------------------------------------------#
+    train_annotation_path   = '2007_train_wm.txt'
+    val_annotation_path     = '2007_val_wm.txt'
+    
+    #----------------------------------------------------#
+    #   获取classes和anchor
+    #----------------------------------------------------#
+    class_names, num_classes = get_classes(classes_path)
+
+    #------------------------------------------------------#
+    #   设置用到的显卡
+    #------------------------------------------------------#
+    os.environ["CUDA_VISIBLE_DEVICES"]  = ','.join(str(x) for x in train_gpu)
+    ngpus_per_node                      = len(train_gpu)
+    print('Number of devices: {}'.format(ngpus_per_node))
+    
+    model = FasterRCNN(num_classes, anchor_scales = anchors_size, backbone = backbone, pretrained = pretrained)
+    if not pretrained:
+        weights_init(model)
+    if model_path != '':
+        #------------------------------------------------------#
+        #   权值文件请看README,百度网盘下载
+        #------------------------------------------------------#
+        print('Load weights {}.'.format(model_path))
+        
+        #------------------------------------------------------#
+        #   根据预训练权重的Key和模型的Key进行加载
+        #------------------------------------------------------#
+        device          = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        model_dict      = model.state_dict()
+        pretrained_dict = torch.load(model_path, map_location = device)
+        load_key, no_load_key, temp_dict = [], [], {}
+        for k, v in pretrained_dict.items():
+            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
+                temp_dict[k] = v
+                load_key.append(k)
+            else:
+                no_load_key.append(k)
+        model_dict.update(temp_dict)
+        model.load_state_dict(model_dict)
+        #------------------------------------------------------#
+        #   显示没有匹配上的Key
+        #------------------------------------------------------#
+        print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
+        print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
+        print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")
+
+    #----------------------#
+    #   记录Loss
+    #----------------------#
+    time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
+    log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
+    loss_history    = LossHistory(log_dir, model, input_shape = input_shape)
+
+    #------------------------------------------------------------------#
+    #   torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16
+    #   因此torch1.2这里显示"could not be resolve"
+    #------------------------------------------------------------------#
+    if fp16:
+        from torch.cuda.amp import GradScaler as GradScaler
+        scaler = GradScaler()
+    else:
+        scaler = None
+
+    model_train     = model.train()
+    if Cuda:
+        model_train = torch.nn.DataParallel(model_train)
+        cudnn.benchmark = True
+        model_train = model_train.cuda()
+
+    #---------------------------#
+    #   读取数据集对应的txt
+    #---------------------------#
+    with open(train_annotation_path, encoding='utf-8') as f:
+        train_lines = f.readlines()
+    with open(val_annotation_path, encoding='utf-8') as f:
+        val_lines   = f.readlines()
+    num_train   = len(train_lines)
+    num_val     = len(val_lines)
+    
+    show_config(
+        classes_path = classes_path, model_path = model_path, input_shape = input_shape, \
+        Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
+        Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
+        save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
+    )
+    #---------------------------------------------------------#
+    #   总训练世代指的是遍历全部数据的总次数
+    #   总训练步长指的是梯度下降的总次数 
+    #   每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。
+    #   此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分
+    #----------------------------------------------------------#
+    wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
+    total_step  = num_train // Unfreeze_batch_size * UnFreeze_Epoch
+    if total_step <= wanted_step:
+        if num_train // Unfreeze_batch_size == 0:
+            raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
+        wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
+        print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m"%(optimizer_type, wanted_step))
+        print("\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
+        print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m"%(total_step, wanted_step, wanted_epoch))
+
+    #------------------------------------------------------#
+    #   主干特征提取网络特征通用,冻结训练可以加快训练速度
+    #   也可以在训练初期防止权值被破坏。
+    #   Init_Epoch为起始世代
+    #   Freeze_Epoch为冻结训练的世代
+    #   UnFreeze_Epoch总训练世代
+    #   提示OOM或者显存不足请调小Batch_size
+    #------------------------------------------------------#
+    if True:
+        UnFreeze_flag = False
+        #------------------------------------#
+        #   冻结一定部分训练
+        #------------------------------------#
+        if Freeze_Train:
+            for param in model.extractor.parameters():
+                param.requires_grad = False
+        # ------------------------------------#
+        #   冻结bn层
+        # ------------------------------------#
+        model.freeze_bn()
+
+        #-------------------------------------------------------------------#
+        #   如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size
+        #-------------------------------------------------------------------#
+        batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
+
+        #-------------------------------------------------------------------#
+        #   判断当前batch_size,自适应调整学习率
+        #-------------------------------------------------------------------#
+        nbs             = 16
+        lr_limit_max    = 1e-4 if optimizer_type == 'adam' else 5e-2
+        lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
+        Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
+        Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
+        
+        #---------------------------------------#
+        #   根据optimizer_type选择优化器
+        #---------------------------------------#
+        optimizer = {
+            'adam'  : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay),
+            'sgd'   : optim.SGD(model.parameters(), Init_lr_fit, momentum = momentum, nesterov=True, weight_decay = weight_decay)
+        }[optimizer_type]
+
+        #---------------------------------------#
+        #   获得学习率下降的公式
+        #---------------------------------------#
+        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
+        
+        #---------------------------------------#
+        #   判断每一个世代的长度
+        #---------------------------------------#
+        epoch_step      = num_train // batch_size
+        epoch_step_val  = num_val // batch_size
+
+        if epoch_step == 0 or epoch_step_val == 0:
+            raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
+
+        train_dataset   = FRCNNDataset(train_lines, input_shape, train = True)
+        val_dataset     = FRCNNDataset(val_lines, input_shape, train = False)
+
+        gen             = DataLoader(train_dataset, shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory = True,
+                                    drop_last = True, collate_fn = frcnn_dataset_collate)
+        gen_val         = DataLoader(val_dataset  , shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory = True, 
+                                    drop_last = True, collate_fn = frcnn_dataset_collate)
+
+        train_util      = FasterRCNNTrainer(model_train, optimizer)
+        #----------------------#
+        #   记录eval的map曲线
+        #----------------------#
+        eval_callback   = EvalCallback(model_train, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \
+                                        eval_flag=eval_flag, period=eval_period)
+
+        #---------------------------------------#
+        #   开始模型训练
+        #---------------------------------------#
+        for epoch in range(Init_Epoch, UnFreeze_Epoch):
+            #---------------------------------------#
+            #   如果模型有冻结学习部分
+            #   则解冻,并设置参数
+            #---------------------------------------#
+            if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
+                batch_size = Unfreeze_batch_size
+
+                #-------------------------------------------------------------------#
+                #   判断当前batch_size,自适应调整学习率
+                #-------------------------------------------------------------------#
+                nbs             = 16
+                lr_limit_max    = 1e-4 if optimizer_type == 'adam' else 5e-2
+                lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
+                Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
+                Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
+                #---------------------------------------#
+                #   获得学习率下降的公式
+                #---------------------------------------#
+                lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
+                
+                for param in model.extractor.parameters():
+                    param.requires_grad = True
+                # ------------------------------------#
+                #   冻结bn层
+                # ------------------------------------#
+                model.freeze_bn()
+
+                epoch_step      = num_train // batch_size
+                epoch_step_val  = num_val // batch_size
+
+                if epoch_step == 0 or epoch_step_val == 0:
+                    raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
+
+                gen             = DataLoader(train_dataset, shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
+                                            drop_last=True, collate_fn=frcnn_dataset_collate)
+                gen_val         = DataLoader(val_dataset  , shuffle = True, batch_size = batch_size, num_workers = num_workers, pin_memory=True, 
+                                            drop_last=True, collate_fn=frcnn_dataset_collate)
+
+                UnFreeze_flag = True
+                
+            set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
+            
+            fit_one_epoch(model, train_util, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir)
+            
+        loss_history.writer.close()

+ 1 - 0
utils/__init__.py

@@ -0,0 +1 @@
+#

+ 67 - 0
utils/anchors.py

@@ -0,0 +1,67 @@
+import numpy as np
+
+#--------------------------------------------#
+#   生成基础的先验框
+#--------------------------------------------#
+def generate_anchor_base(base_size = 16, ratios = [0.5, 1, 2], anchor_scales = [8, 16, 32]):
+    anchor_base = np.zeros((len(ratios) * len(anchor_scales), 4), dtype = np.float32)
+    for i in range(len(ratios)):
+        for j in range(len(anchor_scales)):
+            h = base_size * anchor_scales[j] * np.sqrt(ratios[i])
+            w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i])
+
+            index = i * len(anchor_scales) + j
+            anchor_base[index, 0] = - h / 2.
+            anchor_base[index, 1] = - w / 2.
+            anchor_base[index, 2] = h / 2.
+            anchor_base[index, 3] = w / 2.
+    return anchor_base
+
+#--------------------------------------------#
+#   对基础先验框进行拓展对应到所有特征点上
+#--------------------------------------------#
+def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width):
+    #---------------------------------#
+    #   计算网格中心点
+    #---------------------------------#
+    shift_x             = np.arange(0, width * feat_stride, feat_stride)
+    shift_y             = np.arange(0, height * feat_stride, feat_stride)
+    shift_x, shift_y    = np.meshgrid(shift_x, shift_y)
+    shift               = np.stack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel(),), axis=1)
+
+    #---------------------------------#
+    #   每个网格点上的9个先验框
+    #---------------------------------#
+    A       = anchor_base.shape[0]
+    K       = shift.shape[0]
+    anchor  = anchor_base.reshape((1, A, 4)) + shift.reshape((K, 1, 4))
+    #---------------------------------#
+    #   所有的先验框
+    #---------------------------------#
+    anchor  = anchor.reshape((K * A, 4)).astype(np.float32)
+    return anchor
+    
+if __name__ == "__main__":
+    import matplotlib.pyplot as plt
+    nine_anchors = generate_anchor_base()
+    print(nine_anchors)
+
+    height, width, feat_stride  = 38,38,16
+    anchors_all                 = _enumerate_shifted_anchor(nine_anchors, feat_stride, height, width)
+    print(np.shape(anchors_all))
+    
+    fig     = plt.figure()
+    ax      = fig.add_subplot(111)
+    plt.ylim(-300,900)
+    plt.xlim(-300,900)
+    shift_x = np.arange(0, width * feat_stride, feat_stride)
+    shift_y = np.arange(0, height * feat_stride, feat_stride)
+    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
+    plt.scatter(shift_x,shift_y)
+    box_widths  = anchors_all[:,2]-anchors_all[:,0]
+    box_heights = anchors_all[:,3]-anchors_all[:,1]
+    
+    for i in [108, 109, 110, 111, 112, 113, 114, 115, 116]:
+        rect = plt.Rectangle([anchors_all[i, 0],anchors_all[i, 1]],box_widths[i],box_heights[i],color="r",fill=False)
+        ax.add_patch(rect)
+    plt.show()

+ 237 - 0
utils/callbacks.py

@@ -0,0 +1,237 @@
+import os
+
+import matplotlib
+import torch
+
+matplotlib.use('Agg')
+from matplotlib import pyplot as plt
+import scipy.signal
+
+import shutil
+import numpy as np
+from PIL import Image
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from .utils import cvtColor, resize_image, preprocess_input, get_new_img_size
+from .utils_bbox import DecodeBox
+from .utils_map import get_coco_map, get_map
+
+class LossHistory():
+    def __init__(self, log_dir, model, input_shape):
+        self.log_dir    = log_dir
+        self.losses     = []
+        self.val_loss   = []
+        
+        os.makedirs(self.log_dir)
+        self.writer     = SummaryWriter(self.log_dir)
+        # try:
+        #     dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
+        #     self.writer.add_graph(model, dummy_input)
+        # except:
+        #     pass
+
+    def append_loss(self, epoch, loss, val_loss):
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+
+        self.losses.append(loss)
+        self.val_loss.append(val_loss)
+
+        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
+            f.write(str(loss))
+            f.write("\n")
+        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
+            f.write(str(val_loss))
+            f.write("\n")
+
+        self.writer.add_scalar('loss', loss, epoch)
+        self.writer.add_scalar('val_loss', val_loss, epoch)
+        self.loss_plot()
+
+    def loss_plot(self):
+        iters = range(len(self.losses))
+
+        plt.figure()
+        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
+        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
+        try:
+            if len(self.losses) < 25:
+                num = 5
+            else:
+                num = 15
+            
+            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
+            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
+        except:
+            pass
+
+        plt.grid(True)
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend(loc="upper right")
+
+        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
+
+        plt.cla()
+        plt.close("all")
+
+class EvalCallback():
+    def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
+            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
+        super(EvalCallback, self).__init__()
+        
+        self.net                = net
+        self.input_shape        = input_shape
+        self.class_names        = class_names
+        self.num_classes        = num_classes
+        self.val_lines          = val_lines
+        self.log_dir            = log_dir
+        self.cuda               = cuda
+        self.map_out_path       = map_out_path
+        self.max_boxes          = max_boxes
+        self.confidence         = confidence
+        self.nms_iou            = nms_iou
+        self.letterbox_image    = letterbox_image
+        self.MINOVERLAP         = MINOVERLAP
+        self.eval_flag          = eval_flag
+        self.period             = period
+        
+        self.std    = torch.Tensor([0.1, 0.1, 0.2, 0.2]).repeat(self.num_classes + 1)[None]
+        if self.cuda:
+            self.std    = self.std.cuda()
+        self.bbox_util  = DecodeBox(self.std, self.num_classes)
+
+        self.maps       = [0]
+        self.epoches    = [0]
+        if self.eval_flag:
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(0))
+                f.write("\n")
+
+    #---------------------------------------------------#
+    #   检测图片
+    #---------------------------------------------------#
+    def get_map_txt(self, image_id, image, class_names, map_out_path):
+        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
+        #---------------------------------------------------#
+        #   计算输入图片的高和宽
+        #---------------------------------------------------#
+        image_shape = np.array(np.shape(image)[0:2])
+        input_shape = get_new_img_size(image_shape[0], image_shape[1])
+        #---------------------------------------------------------#
+        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
+        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+        #---------------------------------------------------------#
+        image       = cvtColor(image)
+        
+        #---------------------------------------------------------#
+        #   给原图像进行resize,resize到短边为600的大小上
+        #---------------------------------------------------------#
+        image_data  = resize_image(image, [input_shape[1], input_shape[0]])
+        #---------------------------------------------------------#
+        #   添加上batch_size维度
+        #---------------------------------------------------------#
+        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+
+        with torch.no_grad():
+            images = torch.from_numpy(image_data)
+            if self.cuda:
+                images = images.cuda()
+
+            roi_cls_locs, roi_scores, rois, _ = self.net(images)
+            #-------------------------------------------------------------#
+            #   利用classifier的预测结果对建议框进行解码,获得预测框
+            #-------------------------------------------------------------#
+            results = self.bbox_util.forward(roi_cls_locs, roi_scores, rois, image_shape, input_shape, 
+                                                    nms_iou = self.nms_iou, confidence = self.confidence)
+            #--------------------------------------#
+            #   如果没有检测到物体,则返回原图
+            #--------------------------------------#
+            if len(results[0]) <= 0:
+                return 
+
+            top_label   = np.array(results[0][:, 5], dtype = 'int32')
+            top_conf    = results[0][:, 4]
+            top_boxes   = results[0][:, :4]
+
+        top_100     = np.argsort(top_conf)[::-1][:self.max_boxes]
+        top_boxes   = top_boxes[top_100]
+        top_conf    = top_conf[top_100]
+        top_label   = top_label[top_100]
+
+        for i, c in list(enumerate(top_label)):
+            predicted_class = self.class_names[int(c)]
+            box             = top_boxes[i]
+            score           = str(top_conf[i])
+
+            top, left, bottom, right = box
+            if predicted_class not in class_names:
+                continue
+
+            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
+
+        f.close()
+        return 
+    
+    def on_epoch_end(self, epoch):
+        if epoch % self.period == 0 and self.eval_flag:
+            if not os.path.exists(self.map_out_path):
+                os.makedirs(self.map_out_path)
+            if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
+                os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
+            if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
+                os.makedirs(os.path.join(self.map_out_path, "detection-results"))
+            print("Get map.")
+            for annotation_line in tqdm(self.val_lines):
+                line        = annotation_line.split()
+                image_id    = os.path.basename(line[0]).split('.')[0]
+                #------------------------------#
+                #   读取图像并转换成RGB图像
+                #------------------------------#
+                image       = Image.open(line[0])
+                #------------------------------#
+                #   获得预测框
+                #------------------------------#
+                gt_boxes    = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+                #------------------------------#
+                #   获得预测txt
+                #------------------------------#
+                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
+                
+                #------------------------------#
+                #   获得真实框txt
+                #------------------------------#
+                with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
+                    for box in gt_boxes:
+                        left, top, right, bottom, obj = box
+                        obj_name = self.class_names[obj]
+                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
+                        
+            print("Calculate Map.")
+            try:
+                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
+            except:
+                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
+            self.maps.append(temp_map)
+            self.epoches.append(epoch)
+
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(temp_map))
+                f.write("\n")
+            
+            plt.figure()
+            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
+
+            plt.grid(True)
+            plt.xlabel('Epoch')
+            plt.ylabel('Map %s'%str(self.MINOVERLAP))
+            plt.title('A Map Curve')
+            plt.legend(loc="upper right")
+
+            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
+            plt.cla()
+            plt.close("all")
+
+            print("Get map done.")
+            shutil.rmtree(self.map_out_path)

+ 165 - 0
utils/dataloader.py

@@ -0,0 +1,165 @@
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+from utils.utils import cvtColor, preprocess_input
+
+
+class FRCNNDataset(Dataset):
+    def __init__(self, annotation_lines, input_shape = [600, 600], train = True):
+        self.annotation_lines   = annotation_lines
+        self.length             = len(annotation_lines)
+        self.input_shape        = input_shape
+        self.train              = train
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        index       = index % self.length
+        #---------------------------------------------------#
+        #   训练时进行数据的随机增强
+        #   验证时不进行数据的随机增强
+        #---------------------------------------------------#
+        image, y    = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)
+        image       = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
+        box_data    = np.zeros((len(y), 5))
+        if len(y) > 0:
+            box_data[:len(y)] = y
+
+        box         = box_data[:, :4]
+        label       = box_data[:, -1]
+        return image, box, label
+
+    def rand(self, a=0, b=1):
+        return np.random.rand()*(b-a) + a
+
+    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+        line = annotation_line.split()
+        #------------------------------#
+        #   读取图像并转换成RGB图像
+        #------------------------------#
+        image   = Image.open(line[0])
+        image   = cvtColor(image)
+        #------------------------------#
+        #   获得图像的高宽与目标高宽
+        #------------------------------#
+        iw, ih  = image.size
+        h, w    = input_shape
+        #------------------------------#
+        #   获得预测框
+        #------------------------------#
+        box     = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+
+        if not random:
+            scale = min(w/iw, h/ih)
+            nw = int(iw*scale)
+            nh = int(ih*scale)
+            dx = (w-nw)//2
+            dy = (h-nh)//2
+
+            #---------------------------------#
+            #   将图像多余的部分加上灰条
+            #---------------------------------#
+            image       = image.resize((nw,nh), Image.BICUBIC)
+            new_image   = Image.new('RGB', (w,h), (128,128,128))
+            new_image.paste(image, (dx, dy))
+            image_data  = np.array(new_image, np.float32)
+
+            #---------------------------------#
+            #   对真实框进行调整
+            #---------------------------------#
+            if len(box)>0:
+                np.random.shuffle(box)
+                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+                box[:, 0:2][box[:, 0:2]<0] = 0
+                box[:, 2][box[:, 2]>w] = w
+                box[:, 3][box[:, 3]>h] = h
+                box_w = box[:, 2] - box[:, 0]
+                box_h = box[:, 3] - box[:, 1]
+                box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
+
+            return image_data, box
+                
+        #------------------------------------------#
+        #   对图像进行缩放并且进行长和宽的扭曲
+        #------------------------------------------#
+        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
+        scale = self.rand(.25, 2)
+        if new_ar < 1:
+            nh = int(scale*h)
+            nw = int(nh*new_ar)
+        else:
+            nw = int(scale*w)
+            nh = int(nw/new_ar)
+        image = image.resize((nw,nh), Image.BICUBIC)
+
+        #------------------------------------------#
+        #   将图像多余的部分加上灰条
+        #------------------------------------------#
+        dx = int(self.rand(0, w-nw))
+        dy = int(self.rand(0, h-nh))
+        new_image = Image.new('RGB', (w,h), (128,128,128))
+        new_image.paste(image, (dx, dy))
+        image = new_image
+
+        #------------------------------------------#
+        #   翻转图像
+        #------------------------------------------#
+        flip = self.rand()<.5
+        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
+
+        image_data      = np.array(image, np.uint8)
+        #---------------------------------#
+        #   对图像进行色域变换
+        #   计算色域变换的参数
+        #---------------------------------#
+        r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+        #---------------------------------#
+        #   将图像转到HSV上
+        #---------------------------------#
+        hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+        dtype           = image_data.dtype
+        #---------------------------------#
+        #   应用变换
+        #---------------------------------#
+        x       = np.arange(0, 256, dtype=r.dtype)
+        lut_hue = ((x * r[0]) % 180).astype(dtype)
+        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+        #---------------------------------#
+        #   对真实框进行调整
+        #---------------------------------#
+        if len(box)>0:
+            np.random.shuffle(box)
+            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+            if flip: box[:, [0,2]] = w - box[:, [2,0]]
+            box[:, 0:2][box[:, 0:2]<0] = 0
+            box[:, 2][box[:, 2]>w] = w
+            box[:, 3][box[:, 3]>h] = h
+            box_w = box[:, 2] - box[:, 0]
+            box_h = box[:, 3] - box[:, 1]
+            box = box[np.logical_and(box_w>1, box_h>1)] 
+        
+        return image_data, box
+
+# DataLoader中collate_fn使用
+def frcnn_dataset_collate(batch):
+    images = []
+    bboxes = []
+    labels = []
+    for img, box, label in batch:
+        images.append(img)
+        bboxes.append(box)
+        labels.append(label)
+    images = torch.from_numpy(np.array(images))
+    return images, bboxes, labels
+

+ 62 - 0
utils/utils.py

@@ -0,0 +1,62 @@
+import numpy as np
+from PIL import Image
+
+#---------------------------------------------------------#
+#   将图像转换成RGB图像,防止灰度图在预测时报错。
+#   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
+#---------------------------------------------------------#
+def cvtColor(image):
+    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+        return image 
+    else:
+        image = image.convert('RGB')
+        return image 
+
+#---------------------------------------------------#
+#   对输入图像进行resize
+#---------------------------------------------------#
+def resize_image(image, size):
+    w, h        = size
+    new_image   = image.resize((w, h), Image.BICUBIC)
+    return new_image
+
+#---------------------------------------------------#
+#   获得类
+#---------------------------------------------------#
+def get_classes(classes_path):
+    with open(classes_path, encoding='utf-8') as f:
+        class_names = f.readlines()
+    class_names = [c.strip() for c in class_names]
+    return class_names, len(class_names)
+
+#---------------------------------------------------#
+#   获得学习率
+#---------------------------------------------------#
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+def preprocess_input(image):
+    image /= 255.0
+    return image
+
+def show_config(**kwargs):
+    print('Configurations:')
+    print('-' * 70)
+    print('|%25s | %40s|' % ('keys', 'values'))
+    print('-' * 70)
+    for key, value in kwargs.items():
+        print('|%25s | %40s|' % (str(key), str(value)))
+    print('-' * 70)
+
+def get_new_img_size(height, width, img_min_side=600):
+    if width <= height:
+        f = float(img_min_side) / width
+        resized_height = int(f * height)
+        resized_width = int(img_min_side)
+    else:
+        f = float(img_min_side) / height
+        resized_width = int(f * width)
+        resized_height = int(img_min_side)
+
+    return resized_height, resized_width

+ 131 - 0
utils/utils_bbox.py

@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.ops import nms
+
+
+def loc2bbox(src_bbox, loc):
+    if src_bbox.size()[0] == 0:
+        return torch.zeros((0, 4), dtype=loc.dtype)
+
+    src_width   = torch.unsqueeze(src_bbox[:, 2] - src_bbox[:, 0], -1)
+    src_height  = torch.unsqueeze(src_bbox[:, 3] - src_bbox[:, 1], -1)
+    src_ctr_x   = torch.unsqueeze(src_bbox[:, 0], -1) + 0.5 * src_width
+    src_ctr_y   = torch.unsqueeze(src_bbox[:, 1], -1) + 0.5 * src_height
+
+    dx          = loc[:, 0::4]
+    dy          = loc[:, 1::4]
+    dw          = loc[:, 2::4]
+    dh          = loc[:, 3::4]
+
+    ctr_x = dx * src_width + src_ctr_x
+    ctr_y = dy * src_height + src_ctr_y
+    w = torch.exp(dw) * src_width
+    h = torch.exp(dh) * src_height
+
+    dst_bbox = torch.zeros_like(loc)
+    dst_bbox[:, 0::4] = ctr_x - 0.5 * w
+    dst_bbox[:, 1::4] = ctr_y - 0.5 * h
+    dst_bbox[:, 2::4] = ctr_x + 0.5 * w
+    dst_bbox[:, 3::4] = ctr_y + 0.5 * h
+
+    return dst_bbox
+
+class DecodeBox():
+    def __init__(self, std, num_classes):
+        self.std            = std
+        self.num_classes    = num_classes + 1    
+
+    def frcnn_correct_boxes(self, box_xy, box_wh, input_shape, image_shape):
+        #-----------------------------------------------------------------#
+        #   把y轴放前面是因为方便预测框和图像的宽高进行相乘
+        #-----------------------------------------------------------------#
+        box_yx = box_xy[..., ::-1]
+        box_hw = box_wh[..., ::-1]
+        input_shape = np.array(input_shape)
+        image_shape = np.array(image_shape)
+
+        box_mins    = box_yx - (box_hw / 2.)
+        box_maxes   = box_yx + (box_hw / 2.)
+        boxes  = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
+        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
+        return boxes
+
+    def forward(self, roi_cls_locs, roi_scores, rois, image_shape, input_shape, nms_iou = 0.3, confidence = 0.5):
+        results = []
+        bs      = len(roi_cls_locs)
+        #--------------------------------#
+        #   batch_size, num_rois, 4
+        #--------------------------------#
+        rois    = rois.view((bs, -1, 4))
+        #----------------------------------------------------------------------------------------------------------------#
+        #   对每一张图片进行处理,由于在predict.py的时候,我们只输入一张图片,所以for i in range(len(mbox_loc))只进行一次
+        #----------------------------------------------------------------------------------------------------------------#
+        for i in range(bs):
+            #----------------------------------------------------------#
+            #   对回归参数进行reshape
+            #----------------------------------------------------------#
+            roi_cls_loc = roi_cls_locs[i] * self.std
+            #----------------------------------------------------------#
+            #   第一维度是建议框的数量,第二维度是每个种类
+            #   第三维度是对应种类的调整参数
+            #----------------------------------------------------------#
+            roi_cls_loc = roi_cls_loc.view([-1, self.num_classes, 4])
+
+            #-------------------------------------------------------------#
+            #   利用classifier网络的预测结果对建议框进行调整获得预测框
+            #   num_rois, 4 -> num_rois, 1, 4 -> num_rois, num_classes, 4
+            #-------------------------------------------------------------#
+            roi         = rois[i].view((-1, 1, 4)).expand_as(roi_cls_loc)
+            cls_bbox    = loc2bbox(roi.contiguous().view((-1, 4)), roi_cls_loc.contiguous().view((-1, 4)))
+            cls_bbox    = cls_bbox.view([-1, (self.num_classes), 4])
+            #-------------------------------------------------------------#
+            #   对预测框进行归一化,调整到0-1之间
+            #-------------------------------------------------------------#
+            cls_bbox[..., [0, 2]] = (cls_bbox[..., [0, 2]]) / input_shape[1]
+            cls_bbox[..., [1, 3]] = (cls_bbox[..., [1, 3]]) / input_shape[0]
+
+            roi_score   = roi_scores[i]
+            prob        = F.softmax(roi_score, dim=-1)
+
+            results.append([])
+            for c in range(1, self.num_classes):
+                #--------------------------------#
+                #   取出属于该类的所有框的置信度
+                #   判断是否大于门限
+                #--------------------------------#
+                c_confs     = prob[:, c]
+                c_confs_m   = c_confs > confidence
+
+                if len(c_confs[c_confs_m]) > 0:
+                    #-----------------------------------------#
+                    #   取出得分高于confidence的框
+                    #-----------------------------------------#
+                    boxes_to_process = cls_bbox[c_confs_m, c]
+                    confs_to_process = c_confs[c_confs_m]
+
+                    keep = nms(
+                        boxes_to_process,
+                        confs_to_process,
+                        nms_iou
+                    )
+                    #-----------------------------------------#
+                    #   取出在非极大抑制中效果较好的内容
+                    #-----------------------------------------#
+                    good_boxes  = boxes_to_process[keep]
+                    confs       = confs_to_process[keep][:, None]
+                    labels      = (c - 1) * torch.ones((len(keep), 1)).cuda() if confs.is_cuda else (c - 1) * torch.ones((len(keep), 1))
+                    #-----------------------------------------#
+                    #   将label、置信度、框的位置进行堆叠。
+                    #-----------------------------------------#
+                    c_pred      = torch.cat((good_boxes, confs, labels), dim=1).cpu().numpy()
+                    # 添加进result里
+                    results[-1].extend(c_pred)
+
+            if len(results[-1]) > 0:
+                results[-1] = np.array(results[-1])
+                box_xy, box_wh = (results[-1][:, 0:2] + results[-1][:, 2:4])/2, results[-1][:, 2:4] - results[-1][:, 0:2]
+                results[-1][:, :4] = self.frcnn_correct_boxes(box_xy, box_wh, input_shape, image_shape)
+
+        return results
+        

+ 76 - 0
utils/utils_fit.py

@@ -0,0 +1,76 @@
+import os
+
+import torch
+from tqdm import tqdm
+
+from utils.utils import get_lr
+
+
+def fit_one_epoch(model, train_util, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir):
+    total_loss = 0
+    rpn_loc_loss = 0
+    rpn_cls_loss = 0
+    roi_loc_loss = 0
+    roi_cls_loss = 0
+    
+    val_loss = 0
+    print('Start Train')
+    with tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
+        for iteration, batch in enumerate(gen):
+            if iteration >= epoch_step:
+                break
+            images, boxes, labels = batch[0], batch[1], batch[2]
+            with torch.no_grad():
+                if cuda:
+                    images = images.cuda()
+
+            rpn_loc, rpn_cls, roi_loc, roi_cls, total = train_util.train_step(images, boxes, labels, 1, fp16, scaler)
+            total_loss      += total.item()
+            rpn_loc_loss    += rpn_loc.item()
+            rpn_cls_loss    += rpn_cls.item()
+            roi_loc_loss    += roi_loc.item()
+            roi_cls_loss    += roi_cls.item()
+            
+            pbar.set_postfix(**{'total_loss'    : total_loss / (iteration + 1), 
+                                'rpn_loc'       : rpn_loc_loss / (iteration + 1),  
+                                'rpn_cls'       : rpn_cls_loss / (iteration + 1), 
+                                'roi_loc'       : roi_loc_loss / (iteration + 1), 
+                                'roi_cls'       : roi_cls_loss / (iteration + 1), 
+                                'lr'            : get_lr(optimizer)})
+            pbar.update(1)
+
+    print('Finish Train')
+    print('Start Validation')
+    with tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
+        for iteration, batch in enumerate(gen_val):
+            if iteration >= epoch_step_val:
+                break
+            images, boxes, labels = batch[0], batch[1], batch[2]
+            with torch.no_grad():
+                if cuda:
+                    images = images.cuda()
+
+                train_util.optimizer.zero_grad()
+                _, _, _, _, val_total = train_util.forward(images, boxes, labels, 1)
+                val_loss += val_total.item()
+                
+                pbar.set_postfix(**{'val_loss'  : val_loss / (iteration + 1)})
+                pbar.update(1)
+
+    print('Finish Validation')
+    loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
+    eval_callback.on_epoch_end(epoch + 1)
+    print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
+    print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
+    
+    #-----------------------------------------------#
+    #   保存权值
+    #-----------------------------------------------#
+    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+        torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
+
+    if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
+        print('Save best model to best_epoch_weights.pth')
+        torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
+            
+    torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))

+ 923 - 0
utils/utils_map.py

@@ -0,0 +1,923 @@
+import glob
+import json
+import math
+import operator
+import os
+import shutil
+import sys
+try:
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+except:
+    pass
+import cv2
+import matplotlib
+matplotlib.use('Agg')
+from matplotlib import pyplot as plt
+import numpy as np
+
+'''
+    0,0 ------> x (width)
+     |
+     |  (Left,Top)
+     |      *_________
+     |      |         |
+            |         |
+     y      |_________|
+  (height)            *
+                (Right,Bottom)
+'''
+
+def log_average_miss_rate(precision, fp_cumsum, num_images):
+    """
+        log-average miss rate:
+            Calculated by averaging miss rates at 9 evenly spaced FPPI points
+            between 10e-2 and 10e0, in log-space.
+
+        output:
+                lamr | log-average miss rate
+                mr | miss rate
+                fppi | false positives per image
+
+        references:
+            [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
+               State of the Art." Pattern Analysis and Machine Intelligence, IEEE
+               Transactions on 34.4 (2012): 743 - 761.
+    """
+
+    if precision.size == 0:
+        lamr = 0
+        mr = 1
+        fppi = 0
+        return lamr, mr, fppi
+
+    fppi = fp_cumsum / float(num_images)
+    mr = (1 - precision)
+
+    fppi_tmp = np.insert(fppi, 0, -1.0)
+    mr_tmp = np.insert(mr, 0, 1.0)
+
+    ref = np.logspace(-2.0, 0.0, num = 9)
+    for i, ref_i in enumerate(ref):
+        j = np.where(fppi_tmp <= ref_i)[-1][-1]
+        ref[i] = mr_tmp[j]
+
+    lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
+
+    return lamr, mr, fppi
+
+"""
+ throw error and exit
+"""
+def error(msg):
+    print(msg)
+    sys.exit(0)
+
+"""
+ check if the number is a float between 0.0 and 1.0
+"""
+def is_float_between_0_and_1(value):
+    try:
+        val = float(value)
+        if val > 0.0 and val < 1.0:
+            return True
+        else:
+            return False
+    except ValueError:
+        return False
+
+"""
+ Calculate the AP given the recall and precision array
+    1st) We compute a version of the measured precision/recall curve with
+         precision monotonically decreasing
+    2nd) We compute the AP as the area under this curve by numerical integration.
+"""
+def voc_ap(rec, prec):
+    """
+    --- Official matlab code VOC2012---
+    mrec=[0 ; rec ; 1];
+    mpre=[0 ; prec ; 0];
+    for i=numel(mpre)-1:-1:1
+            mpre(i)=max(mpre(i),mpre(i+1));
+    end
+    i=find(mrec(2:end)~=mrec(1:end-1))+1;
+    ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
+    """
+    rec.insert(0, 0.0) # insert 0.0 at begining of list
+    rec.append(1.0) # insert 1.0 at end of list
+    mrec = rec[:]
+    prec.insert(0, 0.0) # insert 0.0 at begining of list
+    prec.append(0.0) # insert 0.0 at end of list
+    mpre = prec[:]
+    """
+     This part makes the precision monotonically decreasing
+        (goes from the end to the beginning)
+        matlab: for i=numel(mpre)-1:-1:1
+                    mpre(i)=max(mpre(i),mpre(i+1));
+    """
+    for i in range(len(mpre)-2, -1, -1):
+        mpre[i] = max(mpre[i], mpre[i+1])
+    """
+     This part creates a list of indexes where the recall changes
+        matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
+    """
+    i_list = []
+    for i in range(1, len(mrec)):
+        if mrec[i] != mrec[i-1]:
+            i_list.append(i) # if it was matlab would be i + 1
+    """
+     The Average Precision (AP) is the area under the curve
+        (numerical integration)
+        matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
+    """
+    ap = 0.0
+    for i in i_list:
+        ap += ((mrec[i]-mrec[i-1])*mpre[i])
+    return ap, mrec, mpre
+
+
+"""
+ Convert the lines of a file to a list
+"""
+def file_lines_to_list(path):
+    # open txt file lines to a list
+    with open(path) as f:
+        content = f.readlines()
+    # remove whitespace characters like `\n` at the end of each line
+    content = [x.strip() for x in content]
+    return content
+
+"""
+ Draws text in image
+"""
+def draw_text_in_image(img, text, pos, color, line_width):
+    font = cv2.FONT_HERSHEY_PLAIN
+    fontScale = 1
+    lineType = 1
+    bottomLeftCornerOfText = pos
+    cv2.putText(img, text,
+            bottomLeftCornerOfText,
+            font,
+            fontScale,
+            color,
+            lineType)
+    text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
+    return img, (line_width + text_width)
+
+"""
+ Plot - adjust axes
+"""
+def adjust_axes(r, t, fig, axes):
+    # get text width for re-scaling
+    bb = t.get_window_extent(renderer=r)
+    text_width_inches = bb.width / fig.dpi
+    # get axis width in inches
+    current_fig_width = fig.get_figwidth()
+    new_fig_width = current_fig_width + text_width_inches
+    propotion = new_fig_width / current_fig_width
+    # get axis limit
+    x_lim = axes.get_xlim()
+    axes.set_xlim([x_lim[0], x_lim[1]*propotion])
+
+"""
+ Draw plot using Matplotlib
+"""
+def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
+    # sort the dictionary by decreasing value, into a list of tuples
+    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
+    # unpacking the list of tuples into two lists
+    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
+    # 
+    if true_p_bar != "":
+        """
+         Special case to draw in:
+            - green -> TP: True Positives (object detected and matches ground-truth)
+            - red -> FP: False Positives (object detected but does not match ground-truth)
+            - orange -> FN: False Negatives (object not detected but present in the ground-truth)
+        """
+        fp_sorted = []
+        tp_sorted = []
+        for key in sorted_keys:
+            fp_sorted.append(dictionary[key] - true_p_bar[key])
+            tp_sorted.append(true_p_bar[key])
+        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
+        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
+        # add legend
+        plt.legend(loc='lower right')
+        """
+         Write number on side of bar
+        """
+        fig = plt.gcf() # gcf - get current figure
+        axes = plt.gca()
+        r = fig.canvas.get_renderer()
+        for i, val in enumerate(sorted_values):
+            fp_val = fp_sorted[i]
+            tp_val = tp_sorted[i]
+            fp_str_val = " " + str(fp_val)
+            tp_str_val = fp_str_val + " " + str(tp_val)
+            # trick to paint multicolor with offset:
+            # first paint everything and then repaint the first number
+            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
+            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
+            if i == (len(sorted_values)-1): # largest bar
+                adjust_axes(r, t, fig, axes)
+    else:
+        plt.barh(range(n_classes), sorted_values, color=plot_color)
+        """
+         Write number on side of bar
+        """
+        fig = plt.gcf() # gcf - get current figure
+        axes = plt.gca()
+        r = fig.canvas.get_renderer()
+        for i, val in enumerate(sorted_values):
+            str_val = " " + str(val) # add a space before
+            if val < 1.0:
+                str_val = " {0:.2f}".format(val)
+            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
+            # re-set axes to show number inside the figure
+            if i == (len(sorted_values)-1): # largest bar
+                adjust_axes(r, t, fig, axes)
+    # set window title
+    fig.canvas.manager.set_window_title(window_title)
+    # write classes in y axis
+    tick_font_size = 12
+    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
+    """
+     Re-scale height accordingly
+    """
+    init_height = fig.get_figheight()
+    # comput the matrix height in points and inches
+    dpi = fig.dpi
+    height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
+    height_in = height_pt / dpi
+    # compute the required figure height 
+    top_margin = 0.15 # in percentage of the figure height
+    bottom_margin = 0.05 # in percentage of the figure height
+    figure_height = height_in / (1 - top_margin - bottom_margin)
+    # set new height
+    if figure_height > init_height:
+        fig.set_figheight(figure_height)
+
+    # set plot title
+    plt.title(plot_title, fontsize=14)
+    # set axis titles
+    # plt.xlabel('classes')
+    plt.xlabel(x_label, fontsize='large')
+    # adjust size of window
+    fig.tight_layout()
+    # save the plot
+    fig.savefig(output_path)
+    # show image
+    if to_show:
+        plt.show()
+    # close the plot
+    plt.close()
+
+def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
+    GT_PATH             = os.path.join(path, 'ground-truth')
+    DR_PATH             = os.path.join(path, 'detection-results')
+    IMG_PATH            = os.path.join(path, 'images-optional')
+    TEMP_FILES_PATH     = os.path.join(path, '.temp_files')
+    RESULTS_FILES_PATH  = os.path.join(path, 'results')
+
+    show_animation = True
+    if os.path.exists(IMG_PATH): 
+        for dirpath, dirnames, files in os.walk(IMG_PATH):
+            if not files:
+                show_animation = False
+    else:
+        show_animation = False
+
+    if not os.path.exists(TEMP_FILES_PATH):
+        os.makedirs(TEMP_FILES_PATH)
+        
+    if os.path.exists(RESULTS_FILES_PATH):
+        shutil.rmtree(RESULTS_FILES_PATH)
+    else:
+        os.makedirs(RESULTS_FILES_PATH)
+    if draw_plot:
+        try:
+            matplotlib.use('TkAgg')
+        except:
+            pass
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
+    if show_animation:
+        os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
+
+    ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
+    if len(ground_truth_files_list) == 0:
+        error("Error: No ground-truth files found!")
+    ground_truth_files_list.sort()
+    gt_counter_per_class     = {}
+    counter_images_per_class = {}
+
+    for txt_file in ground_truth_files_list:
+        file_id     = txt_file.split(".txt", 1)[0]
+        file_id     = os.path.basename(os.path.normpath(file_id))
+        temp_path   = os.path.join(DR_PATH, (file_id + ".txt"))
+        if not os.path.exists(temp_path):
+            error_msg = "Error. File not found: {}\n".format(temp_path)
+            error(error_msg)
+        lines_list      = file_lines_to_list(txt_file)
+        bounding_boxes  = []
+        is_difficult    = False
+        already_seen_classes = []
+        for line in lines_list:
+            try:
+                if "difficult" in line:
+                    class_name, left, top, right, bottom, _difficult = line.split()
+                    is_difficult = True
+                else:
+                    class_name, left, top, right, bottom = line.split()
+            except:
+                if "difficult" in line:
+                    line_split  = line.split()
+                    _difficult  = line_split[-1]
+                    bottom      = line_split[-2]
+                    right       = line_split[-3]
+                    top         = line_split[-4]
+                    left        = line_split[-5]
+                    class_name  = ""
+                    for name in line_split[:-5]:
+                        class_name += name + " "
+                    class_name  = class_name[:-1]
+                    is_difficult = True
+                else:
+                    line_split  = line.split()
+                    bottom      = line_split[-1]
+                    right       = line_split[-2]
+                    top         = line_split[-3]
+                    left        = line_split[-4]
+                    class_name  = ""
+                    for name in line_split[:-4]:
+                        class_name += name + " "
+                    class_name = class_name[:-1]
+
+            bbox = left + " " + top + " " + right + " " + bottom
+            if is_difficult:
+                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
+                is_difficult = False
+            else:
+                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
+                if class_name in gt_counter_per_class:
+                    gt_counter_per_class[class_name] += 1
+                else:
+                    gt_counter_per_class[class_name] = 1
+
+                if class_name not in already_seen_classes:
+                    if class_name in counter_images_per_class:
+                        counter_images_per_class[class_name] += 1
+                    else:
+                        counter_images_per_class[class_name] = 1
+                    already_seen_classes.append(class_name)
+
+        with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
+            json.dump(bounding_boxes, outfile)
+
+    gt_classes  = list(gt_counter_per_class.keys())
+    gt_classes  = sorted(gt_classes)
+    n_classes   = len(gt_classes)
+
+    dr_files_list = glob.glob(DR_PATH + '/*.txt')
+    dr_files_list.sort()
+    for class_index, class_name in enumerate(gt_classes):
+        bounding_boxes = []
+        for txt_file in dr_files_list:
+            file_id = txt_file.split(".txt",1)[0]
+            file_id = os.path.basename(os.path.normpath(file_id))
+            temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
+            if class_index == 0:
+                if not os.path.exists(temp_path):
+                    error_msg = "Error. File not found: {}\n".format(temp_path)
+                    error(error_msg)
+            lines = file_lines_to_list(txt_file)
+            for line in lines:
+                try:
+                    tmp_class_name, confidence, left, top, right, bottom = line.split()
+                except:
+                    line_split      = line.split()
+                    bottom          = line_split[-1]
+                    right           = line_split[-2]
+                    top             = line_split[-3]
+                    left            = line_split[-4]
+                    confidence      = line_split[-5]
+                    tmp_class_name  = ""
+                    for name in line_split[:-5]:
+                        tmp_class_name += name + " "
+                    tmp_class_name  = tmp_class_name[:-1]
+
+                if tmp_class_name == class_name:
+                    bbox = left + " " + top + " " + right + " " +bottom
+                    bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
+
+        bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
+        with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
+            json.dump(bounding_boxes, outfile)
+
+    sum_AP = 0.0
+    ap_dictionary = {}
+    lamr_dictionary = {}
+    with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
+        results_file.write("# AP and precision/recall per class\n")
+        count_true_positives = {}
+
+        for class_index, class_name in enumerate(gt_classes):
+            count_true_positives[class_name] = 0
+            dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
+            dr_data = json.load(open(dr_file))
+
+            nd          = len(dr_data)
+            tp          = [0] * nd
+            fp          = [0] * nd
+            score       = [0] * nd
+            score_threhold_idx = 0
+            for idx, detection in enumerate(dr_data):
+                file_id     = detection["file_id"]
+                score[idx]  = float(detection["confidence"])
+                if score[idx] >= score_threhold:
+                    score_threhold_idx = idx
+
+                if show_animation:
+                    ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
+                    if len(ground_truth_img) == 0:
+                        error("Error. Image not found with id: " + file_id)
+                    elif len(ground_truth_img) > 1:
+                        error("Error. Multiple image with id: " + file_id)
+                    else:
+                        img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
+                        img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
+                        if os.path.isfile(img_cumulative_path):
+                            img_cumulative = cv2.imread(img_cumulative_path)
+                        else:
+                            img_cumulative = img.copy()
+                        bottom_border = 60
+                        BLACK = [0, 0, 0]
+                        img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
+
+                gt_file             = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
+                ground_truth_data   = json.load(open(gt_file))
+                ovmax       = -1
+                gt_match    = -1
+                bb          = [float(x) for x in detection["bbox"].split()]
+                for obj in ground_truth_data:
+                    if obj["class_name"] == class_name:
+                        bbgt    = [ float(x) for x in obj["bbox"].split() ]
+                        bi      = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
+                        iw      = bi[2] - bi[0] + 1
+                        ih      = bi[3] - bi[1] + 1
+                        if iw > 0 and ih > 0:
+                            ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
+                                            + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
+                            ov = iw * ih / ua
+                            if ov > ovmax:
+                                ovmax = ov
+                                gt_match = obj
+
+                if show_animation:
+                    status = "NO MATCH FOUND!" 
+                    
+                min_overlap = MINOVERLAP
+                if ovmax >= min_overlap:
+                    if "difficult" not in gt_match:
+                        if not bool(gt_match["used"]):
+                            tp[idx] = 1
+                            gt_match["used"] = True
+                            count_true_positives[class_name] += 1
+                            with open(gt_file, 'w') as f:
+                                    f.write(json.dumps(ground_truth_data))
+                            if show_animation:
+                                status = "MATCH!"
+                        else:
+                            fp[idx] = 1
+                            if show_animation:
+                                status = "REPEATED MATCH!"
+                else:
+                    fp[idx] = 1
+                    if ovmax > 0:
+                        status = "INSUFFICIENT OVERLAP"
+
+                """
+                Draw image to show animation
+                """
+                if show_animation:
+                    height, widht = img.shape[:2]
+                    white           = (255,255,255)
+                    light_blue      = (255,200,100)
+                    green           = (0,255,0)
+                    light_red       = (30,30,255)
+                    margin          = 10
+                    # 1nd line
+                    v_pos           = int(height - margin - (bottom_border / 2.0))
+                    text            = "Image: " + ground_truth_img[0] + " "
+                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
+                    text            = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
+                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
+                    if ovmax != -1:
+                        color       = light_red
+                        if status   == "INSUFFICIENT OVERLAP":
+                            text    = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
+                        else:
+                            text    = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
+                            color   = green
+                        img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
+                    # 2nd line
+                    v_pos           += int(bottom_border / 2.0)
+                    rank_pos        = str(idx+1)
+                    text            = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
+                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
+                    color           = light_red
+                    if status == "MATCH!":
+                        color = green
+                    text            = "Result: " + status + " "
+                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
+
+                    font = cv2.FONT_HERSHEY_SIMPLEX
+                    if ovmax > 0: 
+                        bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
+                        cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
+                        cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
+                        cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
+                    bb = [int(i) for i in bb]
+                    cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
+                    cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
+                    cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
+
+                    cv2.imshow("Animation", img)
+                    cv2.waitKey(20) 
+                    output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
+                    cv2.imwrite(output_img_path, img)
+                    cv2.imwrite(img_cumulative_path, img_cumulative)
+
+            cumsum = 0
+            for idx, val in enumerate(fp):
+                fp[idx] += cumsum
+                cumsum += val
+                
+            cumsum = 0
+            for idx, val in enumerate(tp):
+                tp[idx] += cumsum
+                cumsum += val
+
+            rec = tp[:]
+            for idx, val in enumerate(tp):
+                rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
+
+            prec = tp[:]
+            for idx, val in enumerate(tp):
+                prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
+
+            ap, mrec, mprec = voc_ap(rec[:], prec[:])
+            F1  = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
+
+            sum_AP  += ap
+            text    = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
+
+            if len(prec)>0:
+                F1_text         = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
+                Recall_text     = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
+                Precision_text  = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
+            else:
+                F1_text         = "0.00" + " = " + class_name + " F1 " 
+                Recall_text     = "0.00%" + " = " + class_name + " Recall " 
+                Precision_text  = "0.00%" + " = " + class_name + " Precision " 
+
+            rounded_prec    = [ '%.2f' % elem for elem in prec ]
+            rounded_rec     = [ '%.2f' % elem for elem in rec ]
+            results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
+            
+            if len(prec)>0:
+                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
+                    + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
+            else:
+                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
+            ap_dictionary[class_name] = ap
+
+            n_images = counter_images_per_class[class_name]
+            lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
+            lamr_dictionary[class_name] = lamr
+
+            if draw_plot:
+                plt.plot(rec, prec, '-o')
+                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
+                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
+                plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
+
+                fig = plt.gcf()
+                fig.canvas.manager.set_window_title('AP ' + class_name)
+
+                plt.title('class: ' + text)
+                plt.xlabel('Recall')
+                plt.ylabel('Precision')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05]) 
+                fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, F1, "-", color='orangered')
+                plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('F1')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, rec, "-H", color='gold')
+                plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('Recall')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
+                plt.cla()
+
+                plt.plot(score, prec, "-s", color='palevioletred')
+                plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
+                plt.xlabel('Score_Threhold')
+                plt.ylabel('Precision')
+                axes = plt.gca()
+                axes.set_xlim([0.0,1.0])
+                axes.set_ylim([0.0,1.05])
+                fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
+                plt.cla()
+                
+        if show_animation:
+            cv2.destroyAllWindows()
+        if n_classes == 0:
+            print("未检测到任何种类,请检查标签信息与get_map.py中的classes_path是否修改。")
+            return 0
+        results_file.write("\n# mAP of all classes\n")
+        mAP     = sum_AP / n_classes
+        text    = "mAP = {0:.2f}%".format(mAP*100)
+        results_file.write(text + "\n")
+        print(text)
+
+    shutil.rmtree(TEMP_FILES_PATH)
+
+    """
+    Count total of detection-results
+    """
+    det_counter_per_class = {}
+    for txt_file in dr_files_list:
+        lines_list = file_lines_to_list(txt_file)
+        for line in lines_list:
+            class_name = line.split()[0]
+            if class_name in det_counter_per_class:
+                det_counter_per_class[class_name] += 1
+            else:
+                det_counter_per_class[class_name] = 1
+    dr_classes = list(det_counter_per_class.keys())
+
+    """
+    Write number of ground-truth objects per class to results.txt
+    """
+    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+        results_file.write("\n# Number of ground-truth objects per class\n")
+        for class_name in sorted(gt_counter_per_class):
+            results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
+
+    """
+    Finish counting true positives
+    """
+    for class_name in dr_classes:
+        if class_name not in gt_classes:
+            count_true_positives[class_name] = 0
+
+    """
+    Write number of detected objects per class to results.txt
+    """
+    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+        results_file.write("\n# Number of detected objects per class\n")
+        for class_name in sorted(dr_classes):
+            n_det = det_counter_per_class[class_name]
+            text = class_name + ": " + str(n_det)
+            text += " (tp:" + str(count_true_positives[class_name]) + ""
+            text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
+            results_file.write(text)
+
+    """
+    Plot the total number of occurences of each class in the ground-truth
+    """
+    if draw_plot:
+        window_title = "ground-truth-info"
+        plot_title = "ground-truth\n"
+        plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
+        x_label = "Number of objects per class"
+        output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
+        to_show = False
+        plot_color = 'forestgreen'
+        draw_plot_func(
+            gt_counter_per_class,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            '',
+            )
+
+    # """
+    # Plot the total number of occurences of each class in the "detection-results" folder
+    # """
+    # if draw_plot:
+    #     window_title = "detection-results-info"
+    #     # Plot title
+    #     plot_title = "detection-results\n"
+    #     plot_title += "(" + str(len(dr_files_list)) + " files and "
+    #     count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
+    #     plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
+    #     # end Plot title
+    #     x_label = "Number of objects per class"
+    #     output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
+    #     to_show = False
+    #     plot_color = 'forestgreen'
+    #     true_p_bar = count_true_positives
+    #     draw_plot_func(
+    #         det_counter_per_class,
+    #         len(det_counter_per_class),
+    #         window_title,
+    #         plot_title,
+    #         x_label,
+    #         output_path,
+    #         to_show,
+    #         plot_color,
+    #         true_p_bar
+    #         )
+
+    """
+    Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
+    """
+    if draw_plot:
+        window_title = "lamr"
+        plot_title = "log-average miss rate"
+        x_label = "log-average miss rate"
+        output_path = RESULTS_FILES_PATH + "/lamr.png"
+        to_show = False
+        plot_color = 'royalblue'
+        draw_plot_func(
+            lamr_dictionary,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            ""
+            )
+
+    """
+    Draw mAP plot (Show AP's of all classes in decreasing order)
+    """
+    if draw_plot:
+        window_title = "mAP"
+        plot_title = "mAP = {0:.2f}%".format(mAP*100)
+        x_label = "Average Precision"
+        output_path = RESULTS_FILES_PATH + "/mAP.png"
+        to_show = True
+        plot_color = 'royalblue'
+        draw_plot_func(
+            ap_dictionary,
+            n_classes,
+            window_title,
+            plot_title,
+            x_label,
+            output_path,
+            to_show,
+            plot_color,
+            ""
+            )
+    return mAP
+
+def preprocess_gt(gt_path, class_names):
+    image_ids   = os.listdir(gt_path)
+    results = {}
+
+    images = []
+    bboxes = []
+    for i, image_id in enumerate(image_ids):
+        lines_list      = file_lines_to_list(os.path.join(gt_path, image_id))
+        boxes_per_image = []
+        image           = {}
+        image_id        = os.path.splitext(image_id)[0]
+        image['file_name'] = image_id + '.jpg'
+        image['width']     = 1
+        image['height']    = 1
+        #-----------------------------------------------------------------#
+        #   感谢 多学学英语吧 的提醒
+        #   解决了'Results do not correspond to current coco set'问题
+        #-----------------------------------------------------------------#
+        image['id']        = str(image_id)
+
+        for line in lines_list:
+            difficult = 0 
+            if "difficult" in line:
+                line_split  = line.split()
+                left, top, right, bottom, _difficult = line_split[-5:]
+                class_name  = ""
+                for name in line_split[:-5]:
+                    class_name += name + " "
+                class_name  = class_name[:-1]
+                difficult = 1
+            else:
+                line_split  = line.split()
+                left, top, right, bottom = line_split[-4:]
+                class_name  = ""
+                for name in line_split[:-4]:
+                    class_name += name + " "
+                class_name = class_name[:-1]
+            
+            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+            if class_name not in class_names:
+                continue
+            cls_id  = class_names.index(class_name) + 1
+            bbox    = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
+            boxes_per_image.append(bbox)
+        images.append(image)
+        bboxes.extend(boxes_per_image)
+    results['images']        = images
+
+    categories = []
+    for i, cls in enumerate(class_names):
+        category = {}
+        category['supercategory']   = cls
+        category['name']            = cls
+        category['id']              = i + 1
+        categories.append(category)
+    results['categories']   = categories
+
+    annotations = []
+    for i, box in enumerate(bboxes):
+        annotation = {}
+        annotation['area']        = box[-1]
+        annotation['category_id'] = box[-2]
+        annotation['image_id']    = box[-3]
+        annotation['iscrowd']     = box[-4]
+        annotation['bbox']        = box[:4]
+        annotation['id']          = i
+        annotations.append(annotation)
+    results['annotations'] = annotations
+    return results
+
+def preprocess_dr(dr_path, class_names):
+    image_ids = os.listdir(dr_path)
+    results = []
+    for image_id in image_ids:
+        lines_list      = file_lines_to_list(os.path.join(dr_path, image_id))
+        image_id        = os.path.splitext(image_id)[0]
+        for line in lines_list:
+            line_split  = line.split()
+            confidence, left, top, right, bottom = line_split[-5:]
+            class_name  = ""
+            for name in line_split[:-5]:
+                class_name += name + " "
+            class_name  = class_name[:-1]
+            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+            result                  = {}
+            result["image_id"]      = str(image_id)
+            if class_name not in class_names:
+                continue
+            result["category_id"]   = class_names.index(class_name) + 1
+            result["bbox"]          = [left, top, right - left, bottom - top]
+            result["score"]         = float(confidence)
+            results.append(result)
+    return results
+ 
+def get_coco_map(class_names, path):
+    GT_PATH     = os.path.join(path, 'ground-truth')
+    DR_PATH     = os.path.join(path, 'detection-results')
+    COCO_PATH   = os.path.join(path, 'coco_eval')
+
+    if not os.path.exists(COCO_PATH):
+        os.makedirs(COCO_PATH)
+
+    GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
+    DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
+
+    with open(GT_JSON_PATH, "w") as f:
+        results_gt  = preprocess_gt(GT_PATH, class_names)
+        json.dump(results_gt, f, indent=4)
+
+    with open(DR_JSON_PATH, "w") as f:
+        results_dr  = preprocess_dr(DR_PATH, class_names)
+        json.dump(results_dr, f, indent=4)
+        if len(results_dr) == 0:
+            print("未检测到任何目标。")
+            return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+
+    cocoGt      = COCO(GT_JSON_PATH)
+    cocoDt      = cocoGt.loadRes(DR_JSON_PATH)
+    cocoEval    = COCOeval(cocoGt, cocoDt, 'bbox') 
+    cocoEval.evaluate()
+    cocoEval.accumulate()
+    cocoEval.summarize()
+
+    return cocoEval.stats

+ 150 - 0
voc_annotation.py

@@ -0,0 +1,150 @@
+import os
+import random
+import xml.etree.ElementTree as ET
+import numpy as np
+from utils.utils import get_classes
+
+#--------------------------------------------------------------------------------------------------------------------------------#
+#   annotation_mode用于指定该文件运行时计算的内容
+#   annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
+#   annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
+#   annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
+#--------------------------------------------------------------------------------------------------------------------------------#
+annotation_mode     = 2
+#-------------------------------------------------------------------#
+#   必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息
+#   与训练和预测所用的classes_path一致即可
+#   如果生成的2007_train.txt里面没有目标信息
+#   那么就是因为classes没有设定正确
+#   仅在annotation_mode为0和2的时候有效
+#-------------------------------------------------------------------#
+classes_path        = "/root/autodl-tmp/faster-rcnn-pytorch-master/model_data/voc_classes.txt"
+#--------------------------------------------------------------------------------------------------------------------------------#
+#   trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1
+#   train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1
+#   仅在annotation_mode为0和1的时候有效
+#--------------------------------------------------------------------------------------------------------------------------------#
+trainval_percent    = 0.9
+train_percent       = 0.9
+#-------------------------------------------------------#
+#   指向VOC数据集所在的文件夹
+#   默认指向根目录下的VOC数据集
+#-------------------------------------------------------#
+VOCdevkit_path  = "/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit"
+
+VOCdevkit_sets  = [('2007', 'train'), ('2007', 'val')]
+classes, _      = get_classes(classes_path)
+
+#-------------------------------------------------------#
+#   统计目标数量
+#-------------------------------------------------------#
+photo_nums  = np.zeros(len(VOCdevkit_sets))
+nums        = np.zeros(len(classes))
+def convert_annotation(year, image_id, list_file):
+    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
+    tree=ET.parse(in_file)
+    root = tree.getroot()
+
+    for obj in root.iter('object'):
+        difficult = 0 
+        if obj.find('difficult')!=None:
+            difficult = obj.find('difficult').text
+        cls = obj.find('name').text
+        if cls not in classes or int(difficult)==1:
+            continue
+        cls_id = classes.index(cls)
+        xmlbox = obj.find('bndbox')
+        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
+        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
+        
+        nums[classes.index(cls)] = nums[classes.index(cls)] + 1
+        
+if __name__ == "__main__":
+    random.seed(0)
+   
+
+    if annotation_mode == 0 or annotation_mode == 1:
+        print("Generate txt in ImageSets.")
+        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
+        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
+        temp_xml        = os.listdir(xmlfilepath)
+        total_xml       = []
+        for xml in temp_xml:
+            if xml.endswith(".xml"):
+                total_xml.append(xml)
+
+        num     = len(total_xml)  
+        list    = range(num)  
+        tv      = int(num*trainval_percent)  
+        tr      = int(tv*train_percent)  
+        trainval= random.sample(list,tv)  
+        train   = random.sample(trainval,tr)  
+        
+        print("train and val size",tv)
+        print("train size",tr)
+        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
+        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
+        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
+        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
+        
+        for i in list:  
+            name=total_xml[i][:-4]+'\n'  
+            if i in trainval:  
+                ftrainval.write(name)  
+                if i in train:  
+                    ftrain.write(name)  
+                else:  
+                    fval.write(name)  
+            else:  
+                ftest.write(name)  
+        
+        ftrainval.close()  
+        ftrain.close()  
+        fval.close()  
+        ftest.close()
+        print("Generate txt in ImageSets done.")
+
+    if annotation_mode == 0 or annotation_mode == 2:
+        print("Generate 2007_train.txt and 2007_val.txt for train.")
+        type_index = 0
+        for year, image_set in VOCdevkit_sets:
+            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
+            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
+            for image_id in image_ids:
+                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
+
+                convert_annotation(year, image_id, list_file)
+                list_file.write('\n')
+            photo_nums[type_index] = len(image_ids)
+            type_index += 1
+            list_file.close()
+        print("Generate 2007_train.txt and 2007_val.txt for train done.")
+        
+        def printTable(List1, List2):
+            for i in range(len(List1[0])):
+                print("|", end=' ')
+                for j in range(len(List1)):
+                    print(List1[j][i].rjust(int(List2[j])), end=' ')
+                    print("|", end=' ')
+                print()
+
+        str_nums = [str(int(x)) for x in nums]
+        tableData = [
+            classes, str_nums
+        ]
+        colWidths = [0]*len(tableData)
+        len1 = 0
+        for i in range(len(tableData)):
+            for j in range(len(tableData[i])):
+                if len(tableData[i][j]) > colWidths[i]:
+                    colWidths[i] = len(tableData[i][j])
+        printTable(tableData, colWidths)
+
+        if photo_nums[0] <= 500:
+            print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")
+
+        if np.sum(nums) == 0:
+            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
+            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
+            print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
+            print("(重要的事情说三遍)。")

+ 104 - 0
watermarking.py

@@ -0,0 +1,104 @@
+import os
+import random
+import shutil
+from PIL import Image, ImageDraw
+import xml.etree.ElementTree as ET
+
+def modify_images_with_qrcodes(train_txt_path, original_voc_path, new_voc_path, watermarking_dir='./dataset/watermarking/', percentage=5):
+    # 获取所有QR图片文件
+    qr_files = [f for f in os.listdir(watermarking_dir) if f.startswith('QR_') and f.endswith('.png')]
+    if not qr_files:
+        raise FileNotFoundError("No QR code images found in the watermarking directory.")
+    
+    # 读取训练集文件
+    with open(train_txt_path, 'r') as file:
+        lines = file.readlines()
+
+    # 计算需要修改的图片数量
+    num_images = len(lines)
+    num_samples = int(num_images * (percentage / 100))
+    selected_indices = random.sample(range(len(lines)), num_samples)
+    selected_lines = {lines[i].strip() for i in selected_indices}  # 使用集合避免重复
+
+    updated_lines = []
+
+    # 先拷贝未被选中的干净数据
+    for line in lines:
+        if line.strip() not in selected_lines:
+            image_path, bboxes = line.strip().split(' ', 1)
+            new_image_path = image_path.replace(original_voc_path, new_voc_path)
+            new_xml_path = image_path.replace('JPEGImages', 'Annotations').replace('.jpg', '.xml').replace(original_voc_path, new_voc_path)
+
+            os.makedirs(os.path.dirname(new_image_path), exist_ok=True)
+            shutil.copy2(image_path, new_image_path)
+
+            os.makedirs(os.path.dirname(new_xml_path), exist_ok=True)
+            shutil.copy2(image_path.replace('JPEGImages', 'Annotations').replace('.jpg', '.xml'), new_xml_path)
+
+            updated_lines.append(f"{new_image_path} {bboxes}")
+
+    # 处理选中的图片,添加QR码
+    for line in selected_lines:
+        image_path, original_bboxes = line.split(' ', 1)
+        img = Image.open(image_path)
+
+        # 选择一个随机的QR码
+        qr_file = random.choice(qr_files)
+        qr_path = os.path.join(watermarking_dir, qr_file)
+        qr_image = Image.open(qr_path)
+        qr_width, qr_height = qr_image.size
+
+        # 确保图片足够大以容纳QR码
+        if img.width < qr_width or img.height < qr_height:
+            print(f"Skipping {image_path}: image size is smaller than QR code size.")
+            continue
+
+        # 随机放置QR码
+        x = random.randint(0, img.width - qr_width)
+        y = random.randint(0, img.height - qr_height)
+        img.paste(qr_image, (x, y), qr_image)
+
+        # 更新XML文件
+        xml_path = image_path.replace('JPEGImages', 'Annotations').replace('.jpg', '.xml')
+        tree = ET.parse(xml_path)
+        root = tree.getroot()
+
+        object_elem = ET.Element("object")
+        ET.SubElement(object_elem, "name").text = "person"
+        ET.SubElement(object_elem, "pose").text = "Unspecified"
+        ET.SubElement(object_elem, "truncated").text = "0"
+        ET.SubElement(object_elem, "difficult").text = "0"
+        bndbox_elem = ET.SubElement(object_elem, "bndbox")
+        ET.SubElement(bndbox_elem, "xmin").text = str(x)
+        ET.SubElement(bndbox_elem, "ymin").text = str(y)
+        ET.SubElement(bndbox_elem, "xmax").text = str(x + qr_width)
+        ET.SubElement(bndbox_elem, "ymax").text = str(y + qr_height)
+
+        root.append(object_elem)
+        new_xml_path = xml_path.replace(original_voc_path, new_voc_path)
+        os.makedirs(os.path.dirname(new_xml_path), exist_ok=True)
+        tree.write(new_xml_path)
+
+        # 保存修改后的图片
+        new_image_path = image_path.replace(original_voc_path, new_voc_path)
+        os.makedirs(os.path.dirname(new_image_path), exist_ok=True)
+        img.save(new_image_path)
+
+        new_bboxes = f"{x},{y},{x + qr_width},{y + qr_height},14"
+        updated_lines.append(f"{new_image_path} {original_bboxes} {new_bboxes}")
+
+    # 重写训练集文件
+    with open(train_txt_path, 'w') as file:
+        for line in updated_lines:
+            file.write(line + '\n')
+
+# 以下是主函数
+if __name__ == '__main__':
+    # 设定路径
+    watermarking_dir = '/root/autodl-tmp/yolov5-6.1/datasets/watermarking'
+    dataset_txt_path = '/root/autodl-tmp/faster-rcnn-pytorch-master/2007_train_123.txt'
+    original_voc_path = '/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_new'
+    new_voc_path = '/root/autodl-tmp/faster-rcnn-pytorch-master/VOCdevkit/VOC2007_wm'
+
+    # 修改图像以添加二维码水印并更新XML和train.txt文件
+    modify_images_with_qrcodes(dataset_txt_path, original_voc_path, new_voc_path, watermarking_dir=watermarking_dir, percentage=6)