utils_bbox.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import numpy as np
  2. import torch
  3. from torch.nn import functional as F
  4. from torchvision.ops import nms
  5. def loc2bbox(src_bbox, loc):
  6. if src_bbox.size()[0] == 0:
  7. return torch.zeros((0, 4), dtype=loc.dtype)
  8. src_width = torch.unsqueeze(src_bbox[:, 2] - src_bbox[:, 0], -1)
  9. src_height = torch.unsqueeze(src_bbox[:, 3] - src_bbox[:, 1], -1)
  10. src_ctr_x = torch.unsqueeze(src_bbox[:, 0], -1) + 0.5 * src_width
  11. src_ctr_y = torch.unsqueeze(src_bbox[:, 1], -1) + 0.5 * src_height
  12. dx = loc[:, 0::4]
  13. dy = loc[:, 1::4]
  14. dw = loc[:, 2::4]
  15. dh = loc[:, 3::4]
  16. ctr_x = dx * src_width + src_ctr_x
  17. ctr_y = dy * src_height + src_ctr_y
  18. w = torch.exp(dw) * src_width
  19. h = torch.exp(dh) * src_height
  20. dst_bbox = torch.zeros_like(loc)
  21. dst_bbox[:, 0::4] = ctr_x - 0.5 * w
  22. dst_bbox[:, 1::4] = ctr_y - 0.5 * h
  23. dst_bbox[:, 2::4] = ctr_x + 0.5 * w
  24. dst_bbox[:, 3::4] = ctr_y + 0.5 * h
  25. return dst_bbox
  26. class DecodeBox():
  27. def __init__(self, num_classes):
  28. self.std = torch.Tensor([0.1, 0.1, 0.2, 0.2]).repeat(num_classes + 1)[None]
  29. self.num_classes = num_classes + 1
  30. def frcnn_correct_boxes(self, box_xy, box_wh, input_shape, image_shape):
  31. # -----------------------------------------------------------------#
  32. # 把y轴放前面是因为方便预测框和图像的宽高进行相乘
  33. # -----------------------------------------------------------------#
  34. box_yx = box_xy[..., ::-1]
  35. box_hw = box_wh[..., ::-1]
  36. input_shape = np.array(input_shape)
  37. image_shape = np.array(image_shape)
  38. box_mins = box_yx - (box_hw / 2.)
  39. box_maxes = box_yx + (box_hw / 2.)
  40. boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]],
  41. axis=-1)
  42. boxes *= np.concatenate([image_shape, image_shape], axis=-1)
  43. return boxes
  44. def forward(self, roi_cls_locs, roi_scores, rois, image_shape, input_shape, nms_iou=0.3, confidence=0.5):
  45. roi_cls_locs = torch.from_numpy(roi_cls_locs)
  46. roi_scores = torch.from_numpy(roi_scores)
  47. rois = torch.from_numpy(rois)
  48. results = []
  49. bs = len(roi_cls_locs)
  50. # --------------------------------#
  51. # batch_size, num_rois, 4
  52. # --------------------------------#
  53. rois = rois.view((bs, -1, 4))
  54. # ----------------------------------------------------------------------------------------------------------------#
  55. # 对每一张图片进行处理,由于在predict.py的时候,我们只输入一张图片,所以for i in range(len(mbox_loc))只进行一次
  56. # ----------------------------------------------------------------------------------------------------------------#
  57. for i in range(bs):
  58. # ----------------------------------------------------------#
  59. # 对回归参数进行reshape
  60. # ----------------------------------------------------------#
  61. roi_cls_loc = roi_cls_locs[i] * self.std
  62. # ----------------------------------------------------------#
  63. # 第一维度是建议框的数量,第二维度是每个种类
  64. # 第三维度是对应种类的调整参数
  65. # ----------------------------------------------------------#
  66. roi_cls_loc = roi_cls_loc.view([-1, self.num_classes, 4])
  67. # -------------------------------------------------------------#
  68. # 利用classifier网络的预测结果对建议框进行调整获得预测框
  69. # num_rois, 4 -> num_rois, 1, 4 -> num_rois, num_classes, 4
  70. # -------------------------------------------------------------#
  71. roi = rois[i].view((-1, 1, 4)).expand_as(roi_cls_loc)
  72. cls_bbox = loc2bbox(roi.contiguous().view((-1, 4)), roi_cls_loc.contiguous().view((-1, 4)))
  73. cls_bbox = cls_bbox.view([-1, (self.num_classes), 4])
  74. # -------------------------------------------------------------#
  75. # 对预测框进行归一化,调整到0-1之间
  76. # -------------------------------------------------------------#
  77. cls_bbox[..., [0, 2]] = (cls_bbox[..., [0, 2]]) / input_shape[1]
  78. cls_bbox[..., [1, 3]] = (cls_bbox[..., [1, 3]]) / input_shape[0]
  79. roi_score = roi_scores[i]
  80. prob = F.softmax(roi_score, dim=-1)
  81. results.append([])
  82. for c in range(1, self.num_classes):
  83. # --------------------------------#
  84. # 取出属于该类的所有框的置信度
  85. # 判断是否大于门限
  86. # --------------------------------#
  87. c_confs = prob[:, c]
  88. c_confs_m = c_confs > confidence
  89. if len(c_confs[c_confs_m]) > 0:
  90. # -----------------------------------------#
  91. # 取出得分高于confidence的框
  92. # -----------------------------------------#
  93. boxes_to_process = cls_bbox[c_confs_m, c]
  94. confs_to_process = c_confs[c_confs_m]
  95. keep = nms(
  96. boxes_to_process,
  97. confs_to_process,
  98. nms_iou
  99. )
  100. # -----------------------------------------#
  101. # 取出在非极大抑制中效果较好的内容
  102. # -----------------------------------------#
  103. good_boxes = boxes_to_process[keep]
  104. confs = confs_to_process[keep][:, None]
  105. labels = (c - 1) * torch.ones((len(keep), 1)).cuda() if confs.is_cuda else (c - 1) * torch.ones(
  106. (len(keep), 1))
  107. # -----------------------------------------#
  108. # 将label、置信度、框的位置进行堆叠。
  109. # -----------------------------------------#
  110. c_pred = torch.cat((good_boxes, confs, labels), dim=1).cpu().numpy()
  111. # 添加进result里
  112. results[-1].extend(c_pred)
  113. if len(results[-1]) > 0:
  114. results[-1] = np.array(results[-1])
  115. box_xy, box_wh = (results[-1][:, 0:2] + results[-1][:, 2:4]) / 2, results[-1][:, 2:4] - results[-1][:,
  116. 0:2]
  117. results[-1][:, :4] = self.frcnn_correct_boxes(box_xy, box_wh, input_shape, image_shape)
  118. return results