1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- _cfg = {
- 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
- 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
- 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
- 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
- }
- def _make_layers(cfg, input_size):
- layers = []
- in_channels = 3
- for layer_cfg in cfg:
- if layer_cfg == 'M':
- layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
- input_size = input_size // 2
- else:
- layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg, kernel_size=3, stride=1, padding=1))
- layers.append(nn.BatchNorm2d(num_features=layer_cfg))
- layers.append(nn.ReLU(inplace=True))
- in_channels = layer_cfg
- return nn.Sequential(*layers), input_size
- class VGG(nn.Module):
- def __init__(self, name, input_size=32, num_classes=10):
- super(VGG, self).__init__()
- cfg = _cfg[name]
- self.features, final_size = _make_layers(cfg, input_size)
- self.fc = nn.Linear(512 * final_size * final_size, num_classes)
-
- def forward(self, x):
- x = self.features(x)
- x = x.view(x.size(0), -1)
- x = self.fc(x)
- return x
def VGG11(**kwargs):
    """Build a VGG-11 model.

    Keyword arguments (e.g. ``input_size``, ``num_classes``) are forwarded to
    ``VGG``; with none given this matches the previous zero-argument factory.
    """
    return VGG('VGG11', **kwargs)
def VGG13(**kwargs):
    """Build a VGG-13 model.

    Keyword arguments (e.g. ``input_size``, ``num_classes``) are forwarded to
    ``VGG``; with none given this matches the previous zero-argument factory.
    """
    return VGG('VGG13', **kwargs)
def VGG16(**kwargs):
    """Build a VGG-16 model.

    Keyword arguments (e.g. ``input_size``, ``num_classes``) are forwarded to
    ``VGG``; with none given this matches the previous zero-argument factory.
    """
    return VGG('VGG16', **kwargs)
def VGG19(**kwargs):
    """Build a VGG-19 model.

    Keyword arguments (e.g. ``input_size``, ``num_classes``) are forwarded to
    ``VGG``; with none given this matches the previous zero-argument factory.
    """
    return VGG('VGG19', **kwargs)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='VGG Model Test')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    parser.add_argument('--input_size', default=32, type=int)
    args = parser.parse_args()

    # Size the classifier from the CLI arguments. The previous VGG19() call
    # ignored --input_size/--output_num, so any non-default --input_size made
    # the flattened features mismatch the fc layer.
    model = VGG('VGG19', input_size=args.input_size, num_classes=args.output_num)

    # The first conv is built for a fixed channel count; swap in a stem that
    # accepts --input_channels when the two disagree.
    stem = model.features[0]
    if stem.in_channels != args.input_channels:
        model.features[0] = nn.Conv2d(
            args.input_channels, stem.out_channels,
            kernel_size=stem.kernel_size, stride=stem.stride,
            padding=stem.padding,
        )

    # Smoke-test in inference mode: BatchNorm uses its running statistics and
    # no autograd graph is recorded.
    model.eval()
    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
    with torch.no_grad():
        pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)
|