vgg16.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import torch.nn as nn
  # VGG16 network definition (13 conv layers + 3 fully connected layers).
class VGGNet(nn.Module):
    """VGG16-style image classifier.

    Thirteen 3x3 convolutions (padding 1), each followed by ReLU, are
    interleaved with five 2x2 max-pool stages. A three-layer fully connected
    head follows.

    With 32x32 inputs the five pooling stages shrink the feature map to
    512x1x1. That is the size the classifier's first ``Linear(512, 512)`` was
    built for. Any remaining spatial extent is average-pooled down to 1x1
    first, so inputs larger than 32x32 also work instead of failing with a
    shape mismatch in the classifier. Each spatial dimension still needs to
    be at least 32 to survive the five pooling stages.

    Args:
        num_classes: number of output logits.
        in_channels: number of channels in the input image (3 for RGB).
    """

    # Conv output widths in order; 'M' marks a 2x2 / stride-2 max-pool.
    _CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M')

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(in_channels)
        # Collapse whatever spatial size remains to 1x1 so the classifier's
        # 512-dim input is valid for inputs other than 32x32. For 32x32
        # inputs the map is already 1x1 and this is an identity. It has no
        # parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )

    def _make_layers(self, in_channels=3):
        """Build the convolutional feature extractor described by ``_CFG``.

        Args:
            in_channels: channels of the tensor fed to the first conv.

        Returns:
            An ``nn.Sequential`` of Conv2d/ReLU pairs and MaxPool2d layers.
        """
        layers = []
        for v in self._CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                ]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for NCHW input ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # flatten() also works on non-contiguous (e.g. channels_last) tensors,
        # where view() would raise.
        x = x.flatten(1)
        return self.classifier(x)