import torch
import torch.nn as nn
import torch.nn.functional as F


class Alexnet(nn.Module):
    """AlexNet-style CNN with batch normalization, sized for a given input.

    The width of the first fully connected layer is derived at construction
    time by pushing a dummy tensor through the convolutional feature
    extractor, so the classifier adapts to ``input_channels`` and
    ``input_size``.

    Args:
        input_channels: Number of channels in the input images.
        output_num: Number of output logits produced by the classifier.
        input_size: Height and width of the (square) input images.
    """

    def __init__(self, input_channels, output_num, input_size):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=64,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=64, out_channels=192,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(192),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=192, out_channels=384,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(384),  # batch normalization layer
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=384, out_channels=256,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # batch normalization layer
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
        )
        self.input_channels = input_channels
        self.input_size = input_size
        self._init_classifier(output_num)

    def _init_classifier(self, output_num):
        """Create ``self.classifier`` sized to the flattened feature map.

        A zero tensor of shape ``(1, input_channels, input_size, input_size)``
        is forwarded through ``self.features`` to measure how many values
        one sample's feature map holds. The feature extractor is temporarily
        put in eval mode for this probe. That way the BatchNorm layers do not
        fold the dummy data into their running statistics, which they would
        do in training mode even under ``no_grad``. The previous mode is
        restored afterwards.

        Args:
            output_num: Number of logits produced by the final Linear layer.
        """
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                # Use the configured channel count; a hard-coded 3 would break
                # the first Conv2d for non-RGB inputs.
                dummy_input = torch.zeros(
                    1, self.input_channels, self.input_size, self.input_size
                )
                # Batch size is 1, so numel() is the per-sample feature count.
                features_size = self.features(dummy_input).numel()
        finally:
            self.features.train(was_training)

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(features_size, 1000),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1000, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, output_num),
        )

    def forward(self, x):
        """Return classifier logits of shape ``(batch, output_num)`` for ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='AlexNet Implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    parser.add_argument('--input_size', default=32, type=int)
    args = parser.parse_args()

    model = Alexnet(args.input_channels, args.output_num, args.input_size)
    # This is an inference demo: disable dropout and use BN running stats.
    model.eval()
    tensor = torch.rand(1, args.input_channels, args.input_size,
                        args.input_size)
    with torch.no_grad():
        pred = model(tensor)
    print(model)
    print("Predictions shape:", pred.shape)