# GoogleNet.py (3.6 KB)
  1. import torch
  2. import torch.nn as nn
  3. class Inception(nn.Module):
  4. def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
  5. super(Inception, self).__init__()
  6. # 1x1 conv branch
  7. self.b1 = nn.Sequential(
  8. nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
  9. nn.BatchNorm2d(kernel_1_x),
  10. nn.ReLU(True),
  11. )
  12. # 1x1 conv -> 3x3 conv branch
  13. self.b2 = nn.Sequential(
  14. nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
  15. nn.BatchNorm2d(kernel_3_in),
  16. nn.ReLU(True),
  17. nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
  18. nn.BatchNorm2d(kernel_3_x),
  19. nn.ReLU(True),
  20. )
  21. # 1x1 conv -> 5x5 conv branch
  22. self.b3 = nn.Sequential(
  23. nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
  24. nn.BatchNorm2d(kernel_5_in),
  25. nn.ReLU(True),
  26. nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
  27. nn.BatchNorm2d(kernel_5_x),
  28. nn.ReLU(True),
  29. nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
  30. nn.BatchNorm2d(kernel_5_x),
  31. nn.ReLU(True),
  32. )
  33. # 3x3 pool -> 1x1 conv branch
  34. self.b4 = nn.Sequential(
  35. nn.MaxPool2d(3, stride=1, padding=1),
  36. nn.Conv2d(in_planes, pool_planes, kernel_size=1),
  37. nn.BatchNorm2d(pool_planes),
  38. nn.ReLU(True),
  39. )
  40. def forward(self, x):
  41. y1 = self.b1(x)
  42. y2 = self.b2(x)
  43. y3 = self.b3(x)
  44. y4 = self.b4(x)
  45. return torch.cat([y1, y2, y3, y4], 1)
  46. class GoogLeNet(nn.Module):
  47. def __init__(self, input_channels, output_num):
  48. super(GoogLeNet, self).__init__()
  49. self.pre_layers = nn.Sequential(
  50. nn.Conv2d(input_channels, 192, kernel_size=3, padding=1),
  51. nn.BatchNorm2d(192),
  52. nn.ReLU(True),
  53. )
  54. self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
  55. self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
  56. self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
  57. self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
  58. self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
  59. self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
  60. self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
  61. self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
  62. self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
  63. self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
  64. self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive pooling
  65. self.linear = nn.Linear(1024, output_num)
  66. def forward(self, x):
  67. x = self.pre_layers(x)
  68. x = self.a3(x)
  69. x = self.b3(x)
  70. x = self.max_pool(x)
  71. x = self.a4(x)
  72. x = self.b4(x)
  73. x = self.c4(x)
  74. x = self.d4(x)
  75. x = self.e4(x)
  76. x = self.max_pool(x)
  77. x = self.a5(x)
  78. x = self.b5(x)
  79. x = self.avgpool(x)
  80. x = x.view(x.size(0), -1)
  81. x = self.linear(x)
  82. return x
  83. if __name__ == '__main__':
  84. import argparse
  85. parser = argparse.ArgumentParser(description='GoogLeNet Implementation')
  86. parser.add_argument('--input_channels', default=3, type=int)
  87. parser.add_argument('--output_num', default=10, type=int)
  88. args = parser.parse_args()
  89. model = GoogLeNet(args.input_channels, args.output_num)
  90. tensor = torch.rand(1, args.input_channels, 224, 224) # Example for a larger size
  91. pred = model(tensor)
  92. pred_shape = pred.shape
  93. print(model)
  94. print("Predictions shape:", pred_shape)