123456789101112131415161718192021222324252627282930313233343536 |
import torch
import torch.nn as nn


# Define the VGG16 network structure.
class VGGNet(nn.Module):
    """VGG16 (no batch normalization) sized for 32x32 inputs.

    The convolutional trunk uses the 16-weight-layer VGG layout: 13 3x3
    convolutions (padding 1, so each conv preserves spatial size) in five
    stages, each stage closed by a 2x2 / stride-2 max-pool. A 32x32 input
    therefore leaves the trunk as a 512x1x1 map, which is what the first
    classifier layer, ``Linear(512 * 1 * 1, 512)``, is sized for.

    An adaptive average pool to 1x1 sits between the trunk and the
    classifier, so inputs larger than 32x32 are accepted as well. Inputs
    smaller than 32x32 still fail inside the trunk, because the fifth
    max-pool would be applied to a 1x1 (or smaller) map.

    Args:
        num_classes: Number of outputs of the final linear layer.
        in_channels: Number of channels of the input image (3 for RGB).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(in_channels)
        # Collapse whatever spatial extent remains to 1x1 so the classifier's
        # 512 * 1 * 1 input size holds for any input >= 32x32. When the map
        # is already 1x1 this is an exact no-op, and the layer has no
        # parameters, so state_dict keys match the original model.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def _make_layers(self, in_channels=3):
        """Build the VGG16 convolutional trunk as an ``nn.Sequential``.

        Args:
            in_channels: Channels of the tensor fed to the first convolution.

        Returns:
            ``nn.Sequential`` of Conv2d/ReLU pairs and MaxPool2d layers. The
            layer order (and hence state_dict indices) follows ``cfg``.
        """
        # Output channels of each 3x3 conv; 'M' marks a 2x2 max-pool.
        cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
               512, 512, 512, 'M', 512, 512, 512, 'M']
        layers = []
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class scores.

        Args:
            x: Batched image tensor laid out as (N, in_channels, H, W) with
                H, W >= 32.

        Returns:
            Tensor of shape (N, num_classes) with unnormalized scores (no
            softmax is applied).
        """
        x = self.features(x)
        x = self.avgpool(x)
        # torch.flatten has reshape semantics, so unlike .view it also works
        # when the feature map is not viewable (e.g. non-contiguous layouts).
        x = torch.flatten(x, 1)
        return self.classifier(x)
|