"""ResNet-18 style image classifier assembled from basic residual blocks."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CommonBlock(nn.Module):
    """Residual block whose shortcut is the unmodified input.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN; the block input
    is then added to that result and a final ReLU is applied.

    Because the shortcut is a plain identity, the main-path output must be
    shape-compatible with the input for the in-place addition to succeed.
    That holds for the way this file uses the block (equal channel counts,
    default ``stride=1``).

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (defaults to 1).
    """

    def __init__(self, in_channel: int, out_channel: int, stride: int = 1) -> None:
        super().__init__()
        # bias=False: each convolution is immediately followed by BatchNorm,
        # which has its own learnable shift.
        self.conv1 = nn.Conv2d(
            in_channel, out_channel,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(
            out_channel, out_channel,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + x)``."""
        identity = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        # In-place add targets the freshly computed BN output, never the
        # caller's input tensor.
        out += identity
        return F.relu(out, inplace=True)


class SpecialBlock(nn.Module):
    """Residual block with a projection shortcut.

    The main path matches ``CommonBlock`` (conv3x3 -> BN -> ReLU -> conv3x3
    -> BN). The shortcut is a 1x1 convolution followed by BatchNorm that uses
    the same ``stride`` and ``out_channel`` as the main path, so both branches
    produce tensors of the same shape before they are summed.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels of the output feature map.
        stride: Stride applied by the first 3x3 convolution and by the 1x1
            shortcut convolution.
    """

    def __init__(self, in_channel: int, out_channel: int, stride: int) -> None:
        super().__init__()
        # Registered first so parameter / state_dict ordering is unchanged.
        self.change_channel = nn.Sequential(
            nn.Conv2d(
                in_channel, out_channel,
                kernel_size=1, stride=stride, padding=0, bias=False,
            ),
            nn.BatchNorm2d(out_channel),
        )
        self.conv1 = nn.Conv2d(
            in_channel, out_channel,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(
            out_channel, out_channel,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + change_channel(x))``."""
        identity = self.change_channel(x)
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        out += identity
        return F.relu(out, inplace=True)


class ResNet18(nn.Module):
    """Eighteen-layer residual network producing ``num_classes`` scores.

    Layout:
        * ``prepare``: 7x7 stride-2 convolution, BatchNorm, ReLU and a 3x3
          stride-2 max pool.
        * ``layer1`` .. ``layer4``: two residual blocks each, with 64, 128,
          256 and 512 channels; ``layer2``-``layer4`` open with a stride-2
          ``SpecialBlock``.
        * ``avgpool``: adaptive average pooling to a 1x1 spatial size.
        * ``fc``: linear layer mapping 512 features to ``num_classes``.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Size of the output vector (defaults to 10).
    """

    def __init__(self, input_channels: int, num_classes: int = 10) -> None:
        super().__init__()
        self.prepare = nn.Sequential(
            nn.Conv2d(
                input_channels, 64,
                kernel_size=7, stride=2, padding=3, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.layer1 = nn.Sequential(CommonBlock(64, 64), CommonBlock(64, 64))
        self.layer2 = nn.Sequential(SpecialBlock(64, 128, 2), CommonBlock(128, 128))
        self.layer3 = nn.Sequential(SpecialBlock(128, 256, 2), CommonBlock(256, 256))
        self.layer4 = nn.Sequential(SpecialBlock(256, 512, 2), CommonBlock(512, 512))
        # Adaptive pooling makes the classifier independent of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Batched input of shape ``(N, input_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the raw outputs of
            the final linear layer (no softmax is applied).
        """
        x = self.prepare(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse (N, 512, 1, 1) to (N, 512) for the linear layer.
        x = torch.flatten(x, 1)
        return self.fc(x)


def main() -> None:
    """Build a model from CLI arguments and run one random 224x224 input."""
    import argparse

    parser = argparse.ArgumentParser(description='Resnet Implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    args = parser.parse_args()

    model = ResNet18(args.input_channels, args.output_num)
    # Smoke test only: use inference mode so BatchNorm running statistics are
    # not updated and no autograd graph is recorded.
    model.eval()
    tensor = torch.rand(1, args.input_channels, 224, 224)
    with torch.no_grad():
        pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)


if __name__ == '__main__':
    main()