# utils_bbox.py (6.7 KB)

import numpy as np
import torch
from torch import nn
from torchvision.ops import nms
  5. class BBoxUtility(object):
  6. def __init__(self, num_classes):
  7. self.num_classes = num_classes
  8. def ssd_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
  9. #-----------------------------------------------------------------#
  10. # 把y轴放前面是因为方便预测框和图像的宽高进行相乘
  11. #-----------------------------------------------------------------#
  12. box_yx = box_xy[..., ::-1]
  13. box_hw = box_wh[..., ::-1]
  14. input_shape = np.array(input_shape)
  15. image_shape = np.array(image_shape)
  16. if letterbox_image:
  17. #-----------------------------------------------------------------#
  18. # 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况
  19. # new_shape指的是宽高缩放情况
  20. #-----------------------------------------------------------------#
  21. new_shape = np.round(image_shape * np.min(input_shape/image_shape))
  22. offset = (input_shape - new_shape)/2./input_shape
  23. scale = input_shape/new_shape
  24. box_yx = (box_yx - offset) * scale
  25. box_hw *= scale
  26. box_mins = box_yx - (box_hw / 2.)
  27. box_maxes = box_yx + (box_hw / 2.)
  28. boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
  29. boxes *= np.concatenate([image_shape, image_shape], axis=-1)
  30. return boxes
  31. def decode_boxes(self, mbox_loc, anchors, variances):
  32. # 获得先验框的宽与高
  33. anchor_width = anchors[:, 2] - anchors[:, 0]
  34. anchor_height = anchors[:, 3] - anchors[:, 1]
  35. # 获得先验框的中心点
  36. anchor_center_x = 0.5 * (anchors[:, 2] + anchors[:, 0])
  37. anchor_center_y = 0.5 * (anchors[:, 3] + anchors[:, 1])
  38. # 真实框距离先验框中心的xy轴偏移情况
  39. decode_bbox_center_x = mbox_loc[:, 0] * anchor_width * variances[0]
  40. decode_bbox_center_x += anchor_center_x
  41. decode_bbox_center_y = mbox_loc[:, 1] * anchor_height * variances[0]
  42. decode_bbox_center_y += anchor_center_y
  43. # 真实框的宽与高的求取
  44. decode_bbox_width = torch.exp(mbox_loc[:, 2] * variances[1])
  45. decode_bbox_width *= anchor_width
  46. decode_bbox_height = torch.exp(mbox_loc[:, 3] * variances[1])
  47. decode_bbox_height *= anchor_height
  48. # 获取真实框的左上角与右下角
  49. decode_bbox_xmin = decode_bbox_center_x - 0.5 * decode_bbox_width
  50. decode_bbox_ymin = decode_bbox_center_y - 0.5 * decode_bbox_height
  51. decode_bbox_xmax = decode_bbox_center_x + 0.5 * decode_bbox_width
  52. decode_bbox_ymax = decode_bbox_center_y + 0.5 * decode_bbox_height
  53. # 真实框的左上角与右下角进行堆叠
  54. decode_bbox = torch.cat((decode_bbox_xmin[:, None],
  55. decode_bbox_ymin[:, None],
  56. decode_bbox_xmax[:, None],
  57. decode_bbox_ymax[:, None]), dim=-1)
  58. # 防止超出0与1
  59. decode_bbox = torch.min(torch.max(decode_bbox, torch.zeros_like(decode_bbox)), torch.ones_like(decode_bbox))
  60. return decode_bbox
  61. def decode_box(self, predictions, anchors, image_shape, input_shape, letterbox_image, variances = [0.1, 0.2], nms_iou = 0.3, confidence = 0.5):
  62. #---------------------------------------------------#
  63. # :4是回归预测结果
  64. #---------------------------------------------------#
  65. mbox_loc = torch.from_numpy(predictions[0])
  66. #---------------------------------------------------#
  67. # 获得种类的置信度
  68. #---------------------------------------------------#
  69. mbox_conf = nn.Softmax(-1)(torch.from_numpy(predictions[1]))
  70. results = []
  71. #----------------------------------------------------------------------------------------------------------------#
  72. # 对每一张图片进行处理,由于在predict.py的时候,我们只输入一张图片,所以for i in range(len(mbox_loc))只进行一次
  73. #----------------------------------------------------------------------------------------------------------------#
  74. for i in range(len(mbox_loc)):
  75. results.append([])
  76. #--------------------------------#
  77. # 利用回归结果对先验框进行解码
  78. #--------------------------------#
  79. decode_bbox = self.decode_boxes(mbox_loc[i], anchors, variances)
  80. for c in range(1, self.num_classes):
  81. #--------------------------------#
  82. # 取出属于该类的所有框的置信度
  83. # 判断是否大于门限
  84. #--------------------------------#
  85. c_confs = mbox_conf[i, :, c]
  86. c_confs_m = c_confs > confidence
  87. if len(c_confs[c_confs_m]) > 0:
  88. #-----------------------------------------#
  89. # 取出得分高于confidence的框
  90. #-----------------------------------------#
  91. boxes_to_process = decode_bbox[c_confs_m]
  92. confs_to_process = c_confs[c_confs_m]
  93. keep = nms(
  94. boxes_to_process,
  95. confs_to_process,
  96. nms_iou
  97. )
  98. #-----------------------------------------#
  99. # 取出在非极大抑制中效果较好的内容
  100. #-----------------------------------------#
  101. good_boxes = boxes_to_process[keep]
  102. confs = confs_to_process[keep][:, None]
  103. labels = (c - 1) * torch.ones((len(keep), 1)).cuda() if confs.is_cuda else (c - 1) * torch.ones((len(keep), 1))
  104. #-----------------------------------------#
  105. # 将label、置信度、框的位置进行堆叠。
  106. #-----------------------------------------#
  107. c_pred = torch.cat((good_boxes, labels, confs), dim=1).cpu().numpy()
  108. # 添加进result里
  109. results[-1].extend(c_pred)
  110. if len(results[-1]) > 0:
  111. results[-1] = np.array(results[-1])
  112. box_xy, box_wh = (results[-1][:, 0:2] + results[-1][:, 2:4])/2, results[-1][:, 2:4] - results[-1][:, 0:2]
  113. results[-1][:, :4] = self.ssd_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
  114. return results