# frcnn.py (4.6 KB)
  1. import torch.nn as nn
  2. from nets.classifier import Resnet50RoIHead, VGG16RoIHead
  3. from nets.resnet50 import resnet50
  4. from nets.rpn import RegionProposalNetwork
  5. from nets.vgg16 import decom_vgg16
  6. class FasterRCNN(nn.Module):
  7. def __init__(self, num_classes,
  8. mode = "training",
  9. feat_stride = 16,
  10. anchor_scales = [8, 16, 32],
  11. ratios = [0.5, 1, 2],
  12. backbone = 'vgg',
  13. pretrained = False):
  14. super(FasterRCNN, self).__init__()
  15. self.feat_stride = feat_stride
  16. #---------------------------------#
  17. # 一共存在两个主干
  18. # vgg和resnet50
  19. #---------------------------------#
  20. if backbone == 'vgg':
  21. self.extractor, classifier = decom_vgg16(pretrained)
  22. #---------------------------------#
  23. # 构建建议框网络
  24. #---------------------------------#
  25. self.rpn = RegionProposalNetwork(
  26. 512, 512,
  27. ratios = ratios,
  28. anchor_scales = anchor_scales,
  29. feat_stride = self.feat_stride,
  30. mode = mode
  31. )
  32. #---------------------------------#
  33. # 构建分类器网络
  34. #---------------------------------#
  35. self.head = VGG16RoIHead(
  36. n_class = num_classes + 1,
  37. roi_size = 7,
  38. spatial_scale = 1,
  39. classifier = classifier
  40. )
  41. elif backbone == 'resnet50':
  42. self.extractor, classifier = resnet50(pretrained)
  43. #---------------------------------#
  44. # 构建classifier网络
  45. #---------------------------------#
  46. self.rpn = RegionProposalNetwork(
  47. 1024, 512,
  48. ratios = ratios,
  49. anchor_scales = anchor_scales,
  50. feat_stride = self.feat_stride,
  51. mode = mode
  52. )
  53. #---------------------------------#
  54. # 构建classifier网络
  55. #---------------------------------#
  56. self.head = Resnet50RoIHead(
  57. n_class = num_classes + 1,
  58. roi_size = 14,
  59. spatial_scale = 1,
  60. classifier = classifier
  61. )
  62. def forward(self, x, scale=1., mode="forward"):
  63. if mode == "forward":
  64. #---------------------------------#
  65. # 计算输入图片的大小
  66. #---------------------------------#
  67. img_size = x.shape[2:]
  68. #---------------------------------#
  69. # 利用主干网络提取特征
  70. #---------------------------------#
  71. base_feature = self.extractor.forward(x)
  72. #---------------------------------#
  73. # 获得建议框
  74. #---------------------------------#
  75. _, _, rois, roi_indices, _ = self.rpn.forward(base_feature, img_size, scale)
  76. #---------------------------------------#
  77. # 获得classifier的分类结果和回归结果
  78. #---------------------------------------#
  79. roi_cls_locs, roi_scores = self.head.forward(base_feature, rois, roi_indices, img_size)
  80. return roi_cls_locs, roi_scores, rois, roi_indices
  81. elif mode == "extractor":
  82. #---------------------------------#
  83. # 利用主干网络提取特征
  84. #---------------------------------#
  85. base_feature = self.extractor.forward(x)
  86. return base_feature
  87. elif mode == "rpn":
  88. base_feature, img_size = x
  89. #---------------------------------#
  90. # 获得建议框
  91. #---------------------------------#
  92. rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn.forward(base_feature, img_size, scale)
  93. return rpn_locs, rpn_scores, rois, roi_indices, anchor
  94. elif mode == "head":
  95. base_feature, rois, roi_indices, img_size = x
  96. #---------------------------------------#
  97. # 获得classifier的分类结果和回归结果
  98. #---------------------------------------#
  99. roi_cls_locs, roi_scores = self.head.forward(base_feature, rois, roi_indices, img_size)
  100. return roi_cls_locs, roi_scores
  101. def freeze_bn(self):
  102. for m in self.modules():
  103. if isinstance(m, nn.BatchNorm2d):
  104. m.eval()