VGG19.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. _cfg = {
  5. 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  6. 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  7. 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  8. 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  9. }
  10. def _make_layers(cfg, input_size):
  11. layers = []
  12. in_channels = 3
  13. for layer_cfg in cfg:
  14. if layer_cfg == 'M':
  15. layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
  16. input_size = input_size // 2
  17. else:
  18. layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg, kernel_size=3, stride=1, padding=1))
  19. layers.append(nn.BatchNorm2d(num_features=layer_cfg))
  20. layers.append(nn.ReLU(inplace=True))
  21. in_channels = layer_cfg
  22. return nn.Sequential(*layers), input_size
  23. class VGG(nn.Module):
  24. def __init__(self, name, input_size=32, num_classes=10):
  25. super(VGG, self).__init__()
  26. cfg = _cfg[name]
  27. self.features, final_size = _make_layers(cfg, input_size)
  28. self.fc = nn.Linear(512 * final_size * final_size, num_classes)
  29. def forward(self, x):
  30. x = self.features(x)
  31. x = x.view(x.size(0), -1)
  32. x = self.fc(x)
  33. return x
  34. def VGG11():
  35. return VGG('VGG11')
  36. def VGG13():
  37. return VGG('VGG13')
  38. def VGG16():
  39. return VGG('VGG16')
  40. def VGG19():
  41. return VGG('VGG19')
  42. if __name__ == '__main__':
  43. import argparse
  44. parser = argparse.ArgumentParser(description='VGG Model Test')
  45. parser.add_argument('--input_channels', default=3, type=int)
  46. parser.add_argument('--output_num', default=10, type=int)
  47. parser.add_argument('--input_size', default=32, type=int)
  48. args = parser.parse_args()
  49. model = VGG19() # Changed to use VGG19
  50. tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
  51. pred = model(tensor)
  52. print(model)
  53. print("Predictions shape:", pred.shape)