# vgg16.py
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
  4. #--------------------------------------#
  5. # VGG16的结构
  6. #--------------------------------------#
  7. class VGG(nn.Module):
  8. def __init__(self, features, num_classes=1000, init_weights=True):
  9. super(VGG, self).__init__()
  10. self.features = features
  11. #--------------------------------------#
  12. # 平均池化到7x7大小
  13. #--------------------------------------#
  14. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  15. #--------------------------------------#
  16. # 分类部分
  17. #--------------------------------------#
  18. self.classifier = nn.Sequential(
  19. nn.Linear(512 * 7 * 7, 4096),
  20. nn.ReLU(True),
  21. nn.Dropout(),
  22. nn.Linear(4096, 4096),
  23. nn.ReLU(True),
  24. nn.Dropout(),
  25. nn.Linear(4096, num_classes),
  26. )
  27. if init_weights:
  28. self._initialize_weights()
  29. def forward(self, x):
  30. #--------------------------------------#
  31. # 特征提取
  32. #--------------------------------------#
  33. x = self.features(x)
  34. #--------------------------------------#
  35. # 平均池化
  36. #--------------------------------------#
  37. x = self.avgpool(x)
  38. #--------------------------------------#
  39. # 平铺后
  40. #--------------------------------------#
  41. x = torch.flatten(x, 1)
  42. #--------------------------------------#
  43. # 分类部分
  44. #--------------------------------------#
  45. x = self.classifier(x)
  46. return x
  47. def _initialize_weights(self):
  48. for m in self.modules():
  49. if isinstance(m, nn.Conv2d):
  50. nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
  51. if m.bias is not None:
  52. nn.init.constant_(m.bias, 0)
  53. elif isinstance(m, nn.BatchNorm2d):
  54. nn.init.constant_(m.weight, 1)
  55. nn.init.constant_(m.bias, 0)
  56. elif isinstance(m, nn.Linear):
  57. nn.init.normal_(m.weight, 0, 0.01)
  58. nn.init.constant_(m.bias, 0)
  59. '''
  60. 假设输入图像为(600, 600, 3),随着cfg的循环,特征层变化如下:
  61. 600,600,3 -> 600,600,64 -> 600,600,64 -> 300,300,64 -> 300,300,128 -> 300,300,128 -> 150,150,128 -> 150,150,256 -> 150,150,256 -> 150,150,256
  62. -> 75,75,256 -> 75,75,512 -> 75,75,512 -> 75,75,512 -> 37,37,512 -> 37,37,512 -> 37,37,512 -> 37,37,512
  63. 到cfg结束,我们获得了一个37,37,512的特征层
  64. '''
  65. cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
  66. #--------------------------------------#
  67. # 特征提取部分
  68. #--------------------------------------#
  69. def make_layers(cfg, batch_norm = False):
  70. layers = []
  71. in_channels = 3
  72. for v in cfg:
  73. if v == 'M':
  74. layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)]
  75. else:
  76. conv2d = nn.Conv2d(in_channels, v, kernel_size = 3, padding = 1)
  77. if batch_norm:
  78. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)]
  79. else:
  80. layers += [conv2d, nn.ReLU(inplace = True)]
  81. in_channels = v
  82. return nn.Sequential(*layers)
  83. def decom_vgg16(pretrained = False):
  84. model = VGG(make_layers(cfg))
  85. if pretrained:
  86. state_dict = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth", model_dir = "./model_data")
  87. model.load_state_dict(state_dict)
  88. #----------------------------------------------------------------------------#
  89. # 获取特征提取部分,最终获得一个37,37,1024的特征层
  90. #----------------------------------------------------------------------------#
  91. features = list(model.features)[:30]
  92. #----------------------------------------------------------------------------#
  93. # 获取分类部分,需要除去Dropout部分
  94. #----------------------------------------------------------------------------#
  95. classifier = list(model.classifier)
  96. del classifier[6]
  97. del classifier[5]
  98. del classifier[2]
  99. features = nn.Sequential(*features)
  100. classifier = nn.Sequential(*classifier)
  101. return features, classifier