GoogleNet.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Inception(nn.Module):
  5. def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
  6. super(Inception, self).__init__()
  7. # 1x1 conv branch
  8. self.b1 = nn.Sequential(
  9. nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
  10. nn.BatchNorm2d(kernel_1_x),
  11. nn.ReLU(True),
  12. )
  13. # 1x1 conv -> 3x3 conv branch
  14. self.b2 = nn.Sequential(
  15. nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
  16. nn.BatchNorm2d(kernel_3_in),
  17. nn.ReLU(True),
  18. nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
  19. nn.BatchNorm2d(kernel_3_x),
  20. nn.ReLU(True),
  21. )
  22. # 1x1 conv -> 5x5 conv branch
  23. self.b3 = nn.Sequential(
  24. nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
  25. nn.BatchNorm2d(kernel_5_in),
  26. nn.ReLU(True),
  27. nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
  28. nn.BatchNorm2d(kernel_5_x),
  29. nn.ReLU(True),
  30. nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
  31. nn.BatchNorm2d(kernel_5_x),
  32. nn.ReLU(True),
  33. )
  34. # 3x3 pool -> 1x1 conv branch
  35. self.b4 = nn.Sequential(
  36. nn.MaxPool2d(3, stride=1, padding=1),
  37. nn.Conv2d(in_planes, pool_planes, kernel_size=1),
  38. nn.BatchNorm2d(pool_planes),
  39. nn.ReLU(True),
  40. )
  41. def forward(self, x):
  42. y1 = self.b1(x)
  43. y2 = self.b2(x)
  44. y3 = self.b3(x)
  45. y4 = self.b4(x)
  46. return torch.cat([y1, y2, y3, y4], 1)
  47. class GoogLeNet(nn.Module):
  48. def __init__(self, input_channels, output_num):
  49. super(GoogLeNet, self).__init__()
  50. self.pre_layers = nn.Sequential(
  51. nn.Conv2d(input_channels, 192, kernel_size=3, padding=1),
  52. nn.BatchNorm2d(192),
  53. nn.ReLU(True),
  54. )
  55. self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
  56. self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
  57. self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
  58. self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
  59. self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
  60. self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
  61. self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
  62. self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
  63. self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
  64. self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
  65. self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive pooling
  66. self.linear = nn.Linear(1024, output_num)
  67. def forward(self, x):
  68. x = self.pre_layers(x)
  69. x = self.a3(x)
  70. x = self.b3(x)
  71. x = self.max_pool(x)
  72. x = self.a4(x)
  73. x = self.b4(x)
  74. x = self.c4(x)
  75. x = self.d4(x)
  76. x = self.e4(x)
  77. x = self.max_pool(x)
  78. x = self.a5(x)
  79. x = self.b5(x)
  80. x = self.avgpool(x)
  81. x = x.view(x.size(0), -1)
  82. x = self.linear(x)
  83. return x
  84. if __name__ == '__main__':
  85. import argparse
  86. parser = argparse.ArgumentParser(description='GoogLeNet Implementation')
  87. parser.add_argument('--input_channels', default=3, type=int)
  88. parser.add_argument('--output_num', default=10, type=int)
  89. args = parser.parse_args()
  90. model = GoogLeNet(args.input_channels, args.output_num)
  91. tensor = torch.rand(1, args.input_channels, 224, 224) # Example for a larger size
  92. pred = model(tensor)
  93. pred_shape = pred.shape
  94. print(model)
  95. print("Predictions shape:", pred_shape)