# rpn.py — Region Proposal Network (RPN) head for Faster R-CNN.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.ops import nms

from utils.anchors import _enumerate_shifted_anchor, generate_anchor_base
from utils.utils_bbox import loc2bbox
  8. class ProposalCreator():
  9. def __init__(
  10. self,
  11. mode,
  12. nms_iou = 0.7,
  13. n_train_pre_nms = 12000,
  14. n_train_post_nms = 600,
  15. n_test_pre_nms = 3000,
  16. n_test_post_nms = 300,
  17. min_size = 16
  18. ):
  19. #-----------------------------------#
  20. # 设置预测还是训练
  21. #-----------------------------------#
  22. self.mode = mode
  23. #-----------------------------------#
  24. # 建议框非极大抑制的iou大小
  25. #-----------------------------------#
  26. self.nms_iou = nms_iou
  27. #-----------------------------------#
  28. # 训练用到的建议框数量
  29. #-----------------------------------#
  30. self.n_train_pre_nms = n_train_pre_nms
  31. self.n_train_post_nms = n_train_post_nms
  32. #-----------------------------------#
  33. # 预测用到的建议框数量
  34. #-----------------------------------#
  35. self.n_test_pre_nms = n_test_pre_nms
  36. self.n_test_post_nms = n_test_post_nms
  37. self.min_size = min_size
  38. def __call__(self, loc, score, anchor, img_size, scale=1.):
  39. if self.mode == "training":
  40. n_pre_nms = self.n_train_pre_nms
  41. n_post_nms = self.n_train_post_nms
  42. else:
  43. n_pre_nms = self.n_test_pre_nms
  44. n_post_nms = self.n_test_post_nms
  45. #-----------------------------------#
  46. # 将先验框转换成tensor
  47. #-----------------------------------#
  48. anchor = torch.from_numpy(anchor).type_as(loc)
  49. #-----------------------------------#
  50. # 将RPN网络预测结果转化成建议框
  51. #-----------------------------------#
  52. roi = loc2bbox(anchor, loc)
  53. #-----------------------------------#
  54. # 防止建议框超出图像边缘
  55. #-----------------------------------#
  56. roi[:, [0, 2]] = torch.clamp(roi[:, [0, 2]], min = 0, max = img_size[1])
  57. roi[:, [1, 3]] = torch.clamp(roi[:, [1, 3]], min = 0, max = img_size[0])
  58. #-----------------------------------#
  59. # 建议框的宽高的最小值不可以小于16
  60. #-----------------------------------#
  61. min_size = self.min_size * scale
  62. keep = torch.where(((roi[:, 2] - roi[:, 0]) >= min_size) & ((roi[:, 3] - roi[:, 1]) >= min_size))[0]
  63. #-----------------------------------#
  64. # 将对应的建议框保留下来
  65. #-----------------------------------#
  66. roi = roi[keep, :]
  67. score = score[keep]
  68. #-----------------------------------#
  69. # 根据得分进行排序,取出建议框
  70. #-----------------------------------#
  71. order = torch.argsort(score, descending=True)
  72. if n_pre_nms > 0:
  73. order = order[:n_pre_nms]
  74. roi = roi[order, :]
  75. score = score[order]
  76. #-----------------------------------#
  77. # 对建议框进行非极大抑制
  78. # 使用官方的非极大抑制会快非常多
  79. #-----------------------------------#
  80. keep = nms(roi, score, self.nms_iou)
  81. if len(keep) < n_post_nms:
  82. index_extra = np.random.choice(range(len(keep)), size=(n_post_nms - len(keep)), replace=True)
  83. keep = torch.cat([keep, keep[index_extra]])
  84. keep = keep[:n_post_nms]
  85. roi = roi[keep]
  86. return roi
  87. class RegionProposalNetwork(nn.Module):
  88. def __init__(
  89. self,
  90. in_channels = 512,
  91. mid_channels = 512,
  92. ratios = [0.5, 1, 2],
  93. anchor_scales = [8, 16, 32],
  94. feat_stride = 16,
  95. mode = "training",
  96. ):
  97. super(RegionProposalNetwork, self).__init__()
  98. #-----------------------------------------#
  99. # 生成基础先验框,shape为[9, 4]
  100. #-----------------------------------------#
  101. self.anchor_base = generate_anchor_base(anchor_scales = anchor_scales, ratios = ratios)
  102. n_anchor = self.anchor_base.shape[0]
  103. #-----------------------------------------#
  104. # 先进行一个3x3的卷积,可理解为特征整合
  105. #-----------------------------------------#
  106. self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
  107. #-----------------------------------------#
  108. # 分类预测先验框内部是否包含物体
  109. #-----------------------------------------#
  110. self.score = nn.Conv2d(mid_channels, n_anchor * 2, 1, 1, 0)
  111. #-----------------------------------------#
  112. # 回归预测对先验框进行调整
  113. #-----------------------------------------#
  114. self.loc = nn.Conv2d(mid_channels, n_anchor * 4, 1, 1, 0)
  115. #-----------------------------------------#
  116. # 特征点间距步长
  117. #-----------------------------------------#
  118. self.feat_stride = feat_stride
  119. #-----------------------------------------#
  120. # 用于对建议框解码并进行非极大抑制
  121. #-----------------------------------------#
  122. self.proposal_layer = ProposalCreator(mode)
  123. #--------------------------------------#
  124. # 对FPN的网络部分进行权值初始化
  125. #--------------------------------------#
  126. normal_init(self.conv1, 0, 0.01)
  127. normal_init(self.score, 0, 0.01)
  128. normal_init(self.loc, 0, 0.01)
  129. def forward(self, x, img_size, scale=1.):
  130. n, _, h, w = x.shape
  131. #-----------------------------------------#
  132. # 先进行一个3x3的卷积,可理解为特征整合
  133. #-----------------------------------------#
  134. x = F.relu(self.conv1(x))
  135. #-----------------------------------------#
  136. # 回归预测对先验框进行调整
  137. #-----------------------------------------#
  138. rpn_locs = self.loc(x)
  139. rpn_locs = rpn_locs.permute(0, 2, 3, 1).contiguous().view(n, -1, 4)
  140. #-----------------------------------------#
  141. # 分类预测先验框内部是否包含物体
  142. #-----------------------------------------#
  143. rpn_scores = self.score(x)
  144. rpn_scores = rpn_scores.permute(0, 2, 3, 1).contiguous().view(n, -1, 2)
  145. #--------------------------------------------------------------------------------------#
  146. # 进行softmax概率计算,每个先验框只有两个判别结果
  147. # 内部包含物体或者内部不包含物体,rpn_softmax_scores[:, :, 1]的内容为包含物体的概率
  148. #--------------------------------------------------------------------------------------#
  149. rpn_softmax_scores = F.softmax(rpn_scores, dim=-1)
  150. rpn_fg_scores = rpn_softmax_scores[:, :, 1].contiguous()
  151. rpn_fg_scores = rpn_fg_scores.view(n, -1)
  152. #------------------------------------------------------------------------------------------------#
  153. # 生成先验框,此时获得的anchor是布满网格点的,当输入图片为600,600,3的时候,shape为(12996, 4)
  154. #------------------------------------------------------------------------------------------------#
  155. anchor = _enumerate_shifted_anchor(np.array(self.anchor_base), self.feat_stride, h, w)
  156. rois = list()
  157. roi_indices = list()
  158. for i in range(n):
  159. roi = self.proposal_layer(rpn_locs[i], rpn_fg_scores[i], anchor, img_size, scale = scale)
  160. batch_index = i * torch.ones((len(roi),))
  161. rois.append(roi.unsqueeze(0))
  162. roi_indices.append(batch_index.unsqueeze(0))
  163. rois = torch.cat(rois, dim=0).type_as(x)
  164. roi_indices = torch.cat(roi_indices, dim=0).type_as(x)
  165. anchor = torch.from_numpy(anchor).unsqueeze(0).float().to(x.device)
  166. return rpn_locs, rpn_scores, rois, roi_indices, anchor
  167. def normal_init(m, mean, stddev, truncated=False):
  168. if truncated:
  169. m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation
  170. else:
  171. m.weight.data.normal_(mean, stddev)
  172. m.bias.data.zero_()