"""LeNet-style convolutional network with a small command-line smoke test."""

import torch
import torch.nn as nn


class LeNet(nn.Module):
    """LeNet-style CNN: two conv/max-pool stages followed by three linear layers.

    The width of the first linear layer is not hard-coded; it is measured at
    construction time by pushing a zero tensor of shape
    ``(1, input_channels, input_size, input_size)`` through the feature
    extractor.

    NOTE(review): there are no activation functions between the layers. The
    layer stack is kept exactly as-is so that outputs and ``state_dict`` keys
    are unchanged -- confirm whether the omission is intentional.

    Args:
        input_channels: Number of channels of the input images.
        output_num: Number of output units of the final linear layer.
        input_size: Height/width of the square input used to size the
            classifier.
    """

    def __init__(self, input_channels: int, output_num: int, input_size: int) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2),
        )
        # Both values are read by _init_classifier to build the dummy input.
        self.input_size = input_size
        self.input_channels = input_channels
        self._init_classifier(output_num)

    def _init_classifier(self, output_num: int) -> None:
        """Create ``self.classifier`` sized to the flattened feature map.

        Args:
            output_num: Number of output units of the final linear layer.
        """
        # Only the element count of the result is needed, so the probe pass
        # runs without gradient tracking.
        with torch.no_grad():
            dummy_input = torch.zeros(
                1, self.input_channels, self.input_size, self.input_size
            )
            # Batch size is 1, so numel() equals the per-sample feature count.
            features_size = self.features(dummy_input).numel()

        self.classifier = nn.Sequential(
            nn.Linear(features_size, 120),
            nn.Linear(120, 84),
            nn.Linear(84, output_num),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, flatten per sample, then classify.

        Args:
            x: Input batch whose first dimension is the batch size. Its
                per-sample shape must match the one used at construction,
                otherwise the first linear layer will reject it.

        Returns:
            The output of the final linear layer for each sample.
        """
        x = self.features(x)
        # Keep the batch dimension and flatten everything else.
        x = x.reshape(x.size(0), -1)
        x = self.classifier(x)
        return x


def main() -> None:
    """Build a LeNet from command-line options and run one random input through it."""
    import argparse

    parser = argparse.ArgumentParser(description='LeNet Implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    parser.add_argument('--input_size', default=32, type=int)
    args = parser.parse_args()

    model = LeNet(args.input_channels, args.output_num, args.input_size)
    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
    pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)


if __name__ == '__main__':
    main()