import torch
import torch.nn as nn


def _conv_bn_relu(in_channels: int, out_channels: int, kernel_size: int,
                  padding: int = 0) -> list:
    """Return the ``[Conv2d, BatchNorm2d, ReLU(inplace)]`` layers as a list.

    A flat list (rather than a nested ``nn.Sequential``) is returned so that
    callers can unpack it into their own ``nn.Sequential`` and keep flat,
    consecutive sub-module indices.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                  padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    ]


class Inception(nn.Module):
    """Four-branch Inception block whose branch outputs are concatenated.

    Every branch keeps the spatial size of its input (stride 1, "same"
    padding), so the outputs can be concatenated along the channel axis.
    The block produces
    ``kernel_1_x + kernel_3_x + kernel_5_x + pool_planes`` channels.

    Args:
        in_planes: Number of input channels.
        kernel_1_x: Output channels of the 1x1 branch.
        kernel_3_in: Channels of the 1x1 reduction in front of the 3x3 conv.
        kernel_3_x: Output channels of the 3x3 branch.
        kernel_5_in: Channels of the 1x1 reduction in front of the "5x5" branch.
        kernel_5_x: Output channels of the "5x5" branch.
        pool_planes: Output channels of the 1x1 conv after the max-pool.
    """

    def __init__(self, in_planes: int, kernel_1_x: int, kernel_3_in: int,
                 kernel_3_x: int, kernel_5_in: int, kernel_5_x: int,
                 pool_planes: int) -> None:
        super().__init__()

        # Branch 1: 1x1 conv.
        self.b1 = nn.Sequential(*_conv_bn_relu(in_planes, kernel_1_x, 1))

        # Branch 2: 1x1 reduction -> 3x3 conv.
        self.b2 = nn.Sequential(
            *_conv_bn_relu(in_planes, kernel_3_in, 1),
            *_conv_bn_relu(kernel_3_in, kernel_3_x, 3, padding=1),
        )

        # Branch 3: 1x1 reduction -> two stacked 3x3 convs. The pair of 3x3
        # convs covers a 5x5 window, which is why the parameters are named
        # kernel_5_*; no actual 5x5 kernel is used.
        self.b3 = nn.Sequential(
            *_conv_bn_relu(in_planes, kernel_5_in, 1),
            *_conv_bn_relu(kernel_5_in, kernel_5_x, 3, padding=1),
            *_conv_bn_relu(kernel_5_x, kernel_5_x, 3, padding=1),
        )

        # Branch 4: 3x3 max-pool (stride 1) -> 1x1 conv.
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            *_conv_bn_relu(in_planes, pool_planes, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all four branches on ``x`` and concatenate them on dim 1."""
        branches = [self.b1(x), self.b2(x), self.b3(x), self.b4(x)]
        return torch.cat(branches, 1)


class GoogLeNet(nn.Module):
    """GoogLeNet-style classifier built from :class:`Inception` blocks.

    The stem is a single stride-1 3x3 conv, followed by nine Inception blocks
    with two stride-2 max-pools in between. An adaptive average pool reduces
    the final feature map to 1x1, so the spatial size of the input is not
    fixed by the architecture.

    Args:
        input_channels: Number of channels of the input tensor.
        output_num: Size of the final linear layer's output.
    """

    def __init__(self, input_channels: int, output_num: int) -> None:
        super().__init__()
        self.pre_layers = nn.Sequential(
            *_conv_bn_relu(input_channels, 192, 3, padding=1),
        )

        # Each block's input width equals the sum of the previous block's
        # four branch widths (e.g. a3: 64 + 128 + 32 + 32 = 256 -> b3).
        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        # Parameter-free, so the same module is reused for both downsamplings.
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        # Adaptive pooling collapses any remaining spatial extent to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # b5 outputs 384 + 384 + 128 + 128 = 1024 channels.
        self.linear = nn.Linear(1024, output_num)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 4-D input batch to a ``(batch, output_num)`` tensor."""
        x = self.pre_layers(x)

        x = self.a3(x)
        x = self.b3(x)
        x = self.max_pool(x)

        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)
        x = self.max_pool(x)

        x = self.a5(x)
        x = self.b5(x)

        x = self.avgpool(x)
        # flatten() does not require a view-compatible memory layout.
        x = torch.flatten(x, 1)
        return self.linear(x)


def main() -> None:
    """Build the model from CLI arguments and print it and its output shape."""
    import argparse

    parser = argparse.ArgumentParser(description='GoogLeNet Implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    args = parser.parse_args()

    model = GoogLeNet(args.input_channels, args.output_num)
    # This is a shape smoke test, not training: eval mode keeps the BatchNorm
    # running statistics untouched, and no_grad avoids building an autograd
    # graph for the full-resolution activations.
    model.eval()

    tensor = torch.rand(1, args.input_channels, 224, 224)  # Example for a larger size
    with torch.no_grad():
        pred = model(tensor)
    pred_shape = pred.shape

    print(model)
    print("Predictions shape:", pred_shape)


if __name__ == '__main__':
    main()