import torch
import torch.nn as nn
import torch.nn.functional as F


class CommonBlock(nn.Module):
    """Standard residual block without downsampling (identity shortcut)."""

    def __init__(self, in_channel, out_channel, stride=1):
        super(CommonBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        identity = x                                        # shortcut keeps the input unchanged
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = self.bn2(self.conv2(x))
        x += identity                                       # residual addition
        return F.relu(x, inplace=True)


class SpecialBlock(nn.Module):
    """Residual block with downsampling and channel increase (projection shortcut)."""

    def __init__(self, in_channel, out_channel, stride):
        super(SpecialBlock, self).__init__()
        # 1x1 convolution matches the shortcut to the new spatial size and channel count
        self.change_channel = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_channel)
        )
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        identity = self.change_channel(x)                   # projected shortcut
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = self.bn2(self.conv2(x))
        x += identity
        return F.relu(x, inplace=True)


class ResNet18(nn.Module):
    def __init__(self, input_channels, num_classes=10):
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        self.prepare = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Four stages of two residual blocks each; stages 2-4 halve the resolution
        self.layer1 = nn.Sequential(CommonBlock(64, 64), CommonBlock(64, 64))
        self.layer2 = nn.Sequential(SpecialBlock(64, 128, 2), CommonBlock(128, 128))
        self.layer3 = nn.Sequential(SpecialBlock(128, 256, 2), CommonBlock(256, 256))
        self.layer4 = nn.Sequential(SpecialBlock(256, 512, 2), CommonBlock(512, 512))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prepare(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)                           # flatten to (batch, 512)
        x = self.fc(x)
        return x


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='ResNet-18 implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    args = parser.parse_args()

    model = ResNet18(args.input_channels, args.output_num)
    tensor = torch.rand(1, args.input_channels, 224, 224)
    pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)  # torch.Size([1, output_num])