classifier.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import warnings
  2. import torch
  3. from torch import nn
  4. from torchvision.ops import RoIPool
  5. warnings.filterwarnings("ignore")
  6. class VGG16RoIHead(nn.Module):
  7. def __init__(self, n_class, roi_size, spatial_scale, classifier):
  8. super(VGG16RoIHead, self).__init__()
  9. self.classifier = classifier
  10. #--------------------------------------#
  11. # 对ROIPooling后的的结果进行回归预测
  12. #--------------------------------------#
  13. self.cls_loc = nn.Linear(4096, n_class * 4)
  14. #-----------------------------------#
  15. # 对ROIPooling后的的结果进行分类
  16. #-----------------------------------#
  17. self.score = nn.Linear(4096, n_class)
  18. #-----------------------------------#
  19. # 权值初始化
  20. #-----------------------------------#
  21. normal_init(self.cls_loc, 0, 0.001)
  22. normal_init(self.score, 0, 0.01)
  23. self.roi = RoIPool((roi_size, roi_size), spatial_scale)
  24. def forward(self, x, rois, roi_indices, img_size):
  25. n, _, _, _ = x.shape
  26. if x.is_cuda:
  27. roi_indices = roi_indices.cuda()
  28. rois = rois.cuda()
  29. rois = torch.flatten(rois, 0, 1)
  30. roi_indices = torch.flatten(roi_indices, 0, 1)
  31. rois_feature_map = torch.zeros_like(rois)
  32. rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * x.size()[3]
  33. rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * x.size()[2]
  34. indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim = 1)
  35. #-----------------------------------#
  36. # 利用建议框对公用特征层进行截取
  37. #-----------------------------------#
  38. pool = self.roi(x, indices_and_rois)
  39. #-----------------------------------#
  40. # 利用classifier网络进行特征提取
  41. #-----------------------------------#
  42. pool = pool.view(pool.size(0), -1)
  43. #--------------------------------------------------------------#
  44. # 当输入为一张图片的时候,这里获得的f7的shape为[300, 4096]
  45. #--------------------------------------------------------------#
  46. fc7 = self.classifier(pool)
  47. roi_cls_locs = self.cls_loc(fc7)
  48. roi_scores = self.score(fc7)
  49. roi_cls_locs = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
  50. roi_scores = roi_scores.view(n, -1, roi_scores.size(1))
  51. return roi_cls_locs, roi_scores
  52. class Resnet50RoIHead(nn.Module):
  53. def __init__(self, n_class, roi_size, spatial_scale, classifier):
  54. super(Resnet50RoIHead, self).__init__()
  55. self.classifier = classifier
  56. #--------------------------------------#
  57. # 对ROIPooling后的的结果进行回归预测
  58. #--------------------------------------#
  59. self.cls_loc = nn.Linear(2048, n_class * 4)
  60. #-----------------------------------#
  61. # 对ROIPooling后的的结果进行分类
  62. #-----------------------------------#
  63. self.score = nn.Linear(2048, n_class)
  64. #-----------------------------------#
  65. # 权值初始化
  66. #-----------------------------------#
  67. normal_init(self.cls_loc, 0, 0.001)
  68. normal_init(self.score, 0, 0.01)
  69. self.roi = RoIPool((roi_size, roi_size), spatial_scale)
  70. def forward(self, x, rois, roi_indices, img_size):
  71. n, _, _, _ = x.shape
  72. if x.is_cuda:
  73. roi_indices = roi_indices.cuda()
  74. rois = rois.cuda()
  75. rois = torch.flatten(rois, 0, 1)
  76. roi_indices = torch.flatten(roi_indices, 0, 1)
  77. rois_feature_map = torch.zeros_like(rois)
  78. rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * x.size()[3]
  79. rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * x.size()[2]
  80. indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim = 1)
  81. #-----------------------------------#
  82. # 利用建议框对公用特征层进行截取
  83. #-----------------------------------#
  84. pool = self.roi(x, indices_and_rois)
  85. #-----------------------------------#
  86. # 利用classifier网络进行特征提取
  87. #-----------------------------------#
  88. fc7 = self.classifier(pool)
  89. #--------------------------------------------------------------#
  90. # 当输入为一张图片的时候,这里获得的f7的shape为[300, 2048]
  91. #--------------------------------------------------------------#
  92. fc7 = fc7.view(fc7.size(0), -1)
  93. roi_cls_locs = self.cls_loc(fc7)
  94. roi_scores = self.score(fc7)
  95. roi_cls_locs = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
  96. roi_scores = roi_scores.view(n, -1, roi_scores.size(1))
  97. return roi_cls_locs, roi_scores
  98. def normal_init(m, mean, stddev, truncated = False):
  99. if truncated:
  100. m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation
  101. else:
  102. m.weight.data.normal_(mean, stddev)
  103. m.bias.data.zero_()