VGG19.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import torch
  2. import torch.nn as nn
  3. _cfg = {
  4. 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  5. 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  6. 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  7. 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  8. }
  9. def _make_layers(cfg, input_size):
  10. layers = []
  11. in_channels = 3
  12. for layer_cfg in cfg:
  13. if layer_cfg == 'M':
  14. layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
  15. input_size = input_size // 2
  16. else:
  17. layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg, kernel_size=3, stride=1, padding=1))
  18. layers.append(nn.BatchNorm2d(num_features=layer_cfg))
  19. layers.append(nn.ReLU(inplace=True))
  20. in_channels = layer_cfg
  21. return nn.Sequential(*layers), input_size
  22. class VGG(nn.Module):
  23. def __init__(self, name, input_size=32, num_classes=10):
  24. super(VGG, self).__init__()
  25. cfg = _cfg[name]
  26. self.features, final_size = _make_layers(cfg, input_size)
  27. self.fc = nn.Linear(512 * final_size * final_size, num_classes)
  28. def forward(self, x):
  29. x = self.features(x)
  30. x = x.view(x.size(0), -1)
  31. x = self.fc(x)
  32. return x
  33. def VGG11():
  34. return VGG('VGG11')
  35. def VGG13():
  36. return VGG('VGG13')
  37. def VGG16():
  38. return VGG('VGG16')
  39. def VGG19():
  40. return VGG('VGG19')
  41. if __name__ == '__main__':
  42. import argparse
  43. parser = argparse.ArgumentParser(description='VGG Model Test')
  44. parser.add_argument('--input_channels', default=3, type=int)
  45. parser.add_argument('--output_num', default=10, type=int)
  46. parser.add_argument('--input_size', default=32, type=int)
  47. args = parser.parse_args()
  48. model = VGG19() # Changed to use VGG19
  49. tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
  50. pred = model(tensor)
  51. print(model)
  52. print("Predictions shape:", pred.shape)