import torch.nn as nn


# VGG16 network architecture.
class VGGNet(nn.Module):
    """VGG16-style convolutional classifier.

    The feature extractor is 13 3x3 conv + ReLU layers arranged in five
    stages, each stage ending in a 2x2/stride-2 max-pool. The spatial size
    therefore shrinks by a factor of 32: a 32x32 input (e.g. CIFAR-10) yields
    a 512x1x1 feature map.

    An adaptive average pool then reduces any remaining spatial extent to
    1x1, so inputs larger than 32x32 also work. For a 32x32 input the pool is
    an exact identity. The result is flattened and passed to a 3-layer MLP
    classifier with dropout.

    Args:
        num_classes: number of output logits.
        in_channels: number of channels of the input image (3 for RGB).
    """

    # Conv output widths; 'M' marks a 2x2 max-pool (VGG16 layout).
    _CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M')

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.features = self._make_layers(in_channels)
        # Parameter-free, so state_dict keys match the original model.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )

    def _make_layers(self, in_channels=3):
        """Build the convolutional feature extractor from ``_CFG``.

        Args:
            in_channels: channel count of the tensor fed to the first conv.

        Returns:
            nn.Sequential of Conv2d/ReLU pairs and MaxPool2d layers.
        """
        layers = []
        for v in self._CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # padding=1 keeps the spatial size unchanged for 3x3 kernels.
                layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Assumes ``x`` is (batch, in_channels, H, W) with H, W >= 32
        (five halvings must leave at least a 1x1 map).
        """
        x = self.features(x)
        x = self.avgpool(x)
        # flatten(1) also handles an empty batch, unlike view(0, -1).
        x = x.flatten(1)
        return self.classifier(x)