123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import warnings
- import torch
- from torch import nn
- from torchvision.ops import RoIPool
- warnings.filterwarnings("ignore")
class VGG16RoIHead(nn.Module):
    """RoI head for a VGG16 backbone.

    Pools each proposal from the shared feature map with ``RoIPool``,
    flattens the pooled features, runs them through ``classifier`` and
    feeds the result to two linear heads (box regression and class scores).

    Args:
        n_class: Number of classes. The score head outputs ``n_class``
            values and the regression head ``n_class * 4`` values per RoI.
        roi_size: Side length of the square output produced by ``RoIPool``.
        spatial_scale: Scale factor forwarded unchanged to ``RoIPool``.
        classifier: Module applied to the flattened pooled features. Its
            output must have 4096 features to match the linear heads.
    """

    def __init__(self, n_class, roi_size, spatial_scale, classifier):
        super().__init__()
        self.classifier = classifier
        # Box-regression head applied to the RoI-pooled features.
        self.cls_loc = nn.Linear(4096, n_class * 4)
        # Classification head applied to the RoI-pooled features.
        self.score = nn.Linear(4096, n_class)
        # Weight initialisation: zero-mean normal weights, zero biases.
        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)
        self.roi = RoIPool((roi_size, roi_size), spatial_scale)

    def forward(self, x, rois, roi_indices, img_size):
        """Predict per-RoI box offsets and class scores.

        Args:
            x: Shared feature map; must be 4-D, its first dimension is used
                as the batch size and its last two as height and width.
            rois: Proposal boxes. The first two dimensions are flattened
                together; the last dimension is indexed 0..3, with columns
                0 and 2 scaled by width and columns 1 and 3 by height.
            roi_indices: Per-proposal batch index, flattened the same way
                as ``rois`` and prepended to each box as a fifth column.
            img_size: Indexable pair; ``img_size[0]`` normalises the
                height-scaled columns and ``img_size[1]`` the width-scaled
                ones (presumably the input image ``(height, width)``).

        Returns:
            Tuple ``(roi_cls_locs, roi_scores)`` reshaped to
            ``(n, -1, n_class * 4)`` and ``(n, -1, n_class)`` where ``n`` is
            the batch size of ``x``.
        """
        batch_size, _, feat_h, feat_w = x.shape

        # Follow the feature map's device instead of calling ``.cuda()``,
        # which targets the current CUDA device and ignores every other
        # placement (a different GPU, CPU, other accelerators).
        rois = rois.to(x.device)
        roi_indices = roi_indices.to(x.device)

        rois = torch.flatten(rois, 0, 1)
        roi_indices = torch.flatten(roi_indices, 0, 1)

        # Rescale the boxes from image coordinates to feature-map coordinates.
        rois_feature_map = torch.zeros_like(rois)
        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * feat_w
        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * feat_h
        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim=1)

        # Crop the shared feature layer using the proposal boxes.
        pool = self.roi(x, indices_and_rois)

        # Flatten each pooled RoI, then extract features with the classifier.
        pool = pool.view(pool.size(0), -1)
        # Original author's note: for a single input image fc7 has shape
        # [300, 4096] here.
        fc7 = self.classifier(pool)

        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)

        # Restore the batch dimension that was flattened away above.
        roi_cls_locs = roi_cls_locs.view(batch_size, -1, roi_cls_locs.size(1))
        roi_scores = roi_scores.view(batch_size, -1, roi_scores.size(1))
        return roi_cls_locs, roi_scores
class Resnet50RoIHead(nn.Module):
    """RoI head for a ResNet50 backbone.

    Pools each proposal from the shared feature map with ``RoIPool``, runs
    the pooled maps through ``classifier``, flattens the result and feeds it
    to two linear heads (box regression and class scores).

    Args:
        n_class: Number of classes. The score head outputs ``n_class``
            values and the regression head ``n_class * 4`` values per RoI.
        roi_size: Side length of the square output produced by ``RoIPool``.
        spatial_scale: Scale factor forwarded unchanged to ``RoIPool``.
        classifier: Module applied to the pooled feature maps. Its output,
            once flattened per RoI, must have 2048 features to match the
            linear heads.
    """

    def __init__(self, n_class, roi_size, spatial_scale, classifier):
        super().__init__()
        self.classifier = classifier
        # Box-regression head applied to the RoI-pooled features.
        self.cls_loc = nn.Linear(2048, n_class * 4)
        # Classification head applied to the RoI-pooled features.
        self.score = nn.Linear(2048, n_class)
        # Weight initialisation: zero-mean normal weights, zero biases.
        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)
        self.roi = RoIPool((roi_size, roi_size), spatial_scale)

    def forward(self, x, rois, roi_indices, img_size):
        """Predict per-RoI box offsets and class scores.

        Args:
            x: Shared feature map; must be 4-D, its first dimension is used
                as the batch size and its last two as height and width.
            rois: Proposal boxes. The first two dimensions are flattened
                together; the last dimension is indexed 0..3, with columns
                0 and 2 scaled by width and columns 1 and 3 by height.
            roi_indices: Per-proposal batch index, flattened the same way
                as ``rois`` and prepended to each box as a fifth column.
            img_size: Indexable pair; ``img_size[0]`` normalises the
                height-scaled columns and ``img_size[1]`` the width-scaled
                ones (presumably the input image ``(height, width)``).

        Returns:
            Tuple ``(roi_cls_locs, roi_scores)`` reshaped to
            ``(n, -1, n_class * 4)`` and ``(n, -1, n_class)`` where ``n`` is
            the batch size of ``x``.
        """
        batch_size, _, feat_h, feat_w = x.shape

        # Follow the feature map's device instead of calling ``.cuda()``,
        # which targets the current CUDA device and ignores every other
        # placement (a different GPU, CPU, other accelerators).
        rois = rois.to(x.device)
        roi_indices = roi_indices.to(x.device)

        rois = torch.flatten(rois, 0, 1)
        roi_indices = torch.flatten(roi_indices, 0, 1)

        # Rescale the boxes from image coordinates to feature-map coordinates.
        rois_feature_map = torch.zeros_like(rois)
        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * feat_w
        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * feat_h
        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim=1)

        # Crop the shared feature layer using the proposal boxes.
        pool = self.roi(x, indices_and_rois)

        # Extract features with the classifier network, then flatten per RoI.
        fc7 = self.classifier(pool)
        # Original author's note: for a single input image fc7 has shape
        # [300, 2048] here.
        fc7 = fc7.view(fc7.size(0), -1)

        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)

        # Restore the batch dimension that was flattened away above.
        roi_cls_locs = roi_cls_locs.view(batch_size, -1, roi_cls_locs.size(1))
        roi_scores = roi_scores.view(batch_size, -1, roi_scores.size(1))
        return roi_cls_locs, roi_scores
def normal_init(m, mean, stddev, truncated = False):
    """Initialise a layer's weight from a normal distribution and zero its bias.

    Args:
        m: Module exposing a ``weight`` tensor and a ``bias`` attribute
            (``bias`` may be ``None``, e.g. ``nn.Linear(..., bias=False)``).
        mean: Mean of the weight distribution.
        stddev: Standard deviation of the weight distribution.
        truncated: When true, draw standard-normal samples, fold them into
            the open interval (-2, 2) with ``fmod`` and then scale and shift
            them; this only approximates a truncated normal.
    """
    # Mutate the parameters in place without recording the ops in autograd.
    with torch.no_grad():
        if truncated:
            # Not a perfect approximation of a truncated normal.
            m.weight.normal_().fmod_(2).mul_(stddev).add_(mean)
        else:
            m.weight.normal_(mean, stddev)
        # Layers built without a bias have ``bias is None``; skip them
        # instead of raising AttributeError.
        if m.bias is not None:
            m.bias.zero_()
|