1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- import torch
- import torch.nn as nn
class Alexnet(nn.Module):
    """AlexNet-style convolutional classifier with batch normalization.

    The feature extractor stacks five 3x3 convolutions, each followed by
    ``BatchNorm2d`` and ``ReLU``. Max-pooling (kernel 2) is also applied
    after the first, second and fifth convolutions. The fully connected
    classifier's first ``Linear`` layer is sized automatically from the
    feature extractor's output at the configured input resolution.

    Args:
        input_channels: Number of channels of the input images.
        output_num: Number of output logits produced by the classifier.
        input_size: Height and width of the square input images. The
            classifier is sized for exactly this resolution, so inputs of a
            different spatial size will not fit its first ``Linear`` layer.
    """

    def __init__(self, input_channels, output_num, input_size):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=64,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),  # batch normalization layer
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # batch normalization layer
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # batch normalization layer
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
        )

        self.input_channels = input_channels
        self.input_size = input_size
        self._init_classifier(output_num)

    def _init_classifier(self, output_num):
        """Build ``self.classifier`` sized to the feature extractor's output.

        A zero-valued dummy image of shape
        ``(1, input_channels, input_size, input_size)`` is passed through
        ``self.features``. The number of elements in the result is the
        flattened feature count per sample.

        Args:
            output_num: Number of logits produced by the final ``Linear`` layer.
        """
        # The probe must use the configured channel count. Hard-coding 3
        # made construction fail for any input_channels != 3.
        dummy_input = torch.zeros(
            1, self.input_channels, self.input_size, self.input_size
        )

        # Probe in eval mode so BatchNorm reads, but does not update, its
        # running statistics. no_grad() alone does not stop train-mode BN
        # from changing running_mean/running_var and num_batches_tracked.
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                features_size = self.features(dummy_input).numel()
        finally:
            self.features.train(was_training)

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(features_size, 1000),
            nn.ReLU(inplace=True),

            nn.Dropout(0.5),
            nn.Linear(1000, 256),
            nn.ReLU(inplace=True),

            nn.Linear(256, output_num),
        )

    def forward(self, x):
        """Return classifier logits for a batch of images.

        Args:
            x: Image batch laid out as ``(N, input_channels, H, W)``. H and W
                must match the ``input_size`` the classifier was sized for.

        Returns:
            Tensor of shape ``(N, output_num)``.
        """
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        return self.classifier(x)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='AlexNet Implementation')
    parser.add_argument('--input_channels', default=3, type=int,
                        help='number of channels in the input images')
    parser.add_argument('--output_num', default=10, type=int,
                        help='number of output logits (classes)')
    parser.add_argument('--input_size', default=32, type=int,
                        help='height and width of the square input images')
    args = parser.parse_args()

    model = Alexnet(args.input_channels, args.output_num, args.input_size)
    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)

    # This demo forward pass is inference only. Eval mode disables dropout
    # and stops BatchNorm from updating its running statistics. no_grad()
    # avoids building an autograd graph that would never be used.
    model.eval()
    with torch.no_grad():
        pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)
|