resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class CommonBlock(nn.Module):
    """Standard residual block without downsampling (the identity skip requires in_channel == out_channel and stride == 1)."""

    def __init__(self, in_channel, out_channel, stride=1):
        super(CommonBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        identity = x  # identity skip connection
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = self.bn2(self.conv2(x))
        x += identity  # residual addition
        return F.relu(x, inplace=True)
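
# Illustrative shape check (not part of the original file): CommonBlock keeps
# both the channel count and the spatial size, e.g.
#   CommonBlock(64, 64)(torch.rand(1, 64, 56, 56)).shape == torch.Size([1, 64, 56, 56])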


class SpecialBlock(nn.Module):
    """Residual block that downsamples spatially and increases the channel count; a 1x1 projection on the skip path matches the shapes."""

    def __init__(self, in_channel, out_channel, stride):
        super(SpecialBlock, self).__init__()
        # 1x1 convolution projects the identity to the new channel count and stride
        self.change_channel = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_channel)
        )
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        identity = self.change_channel(x)  # projected skip connection
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = self.bn2(self.conv2(x))
        x += identity
        return F.relu(x, inplace=True)
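
# Illustrative shape check (not part of the original file): with stride=2,
# SpecialBlock halves the spatial size and changes the channel count, e.g.
#   SpecialBlock(64, 128, 2)(torch.rand(1, 64, 56, 56)).shape == torch.Size([1, 128, 28, 28])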


class ResNet18(nn.Module):
    """ResNet-18: a 7x7 convolutional stem, four stages of two residual blocks each, global average pooling, and a linear classifier."""

    def __init__(self, input_channels, num_classes=10):
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        self.prepare = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Four stages; each SpecialBlock halves the spatial size and doubles the channel count
        self.layer1 = nn.Sequential(CommonBlock(64, 64), CommonBlock(64, 64))
        self.layer2 = nn.Sequential(SpecialBlock(64, 128, 2), CommonBlock(128, 128))
        self.layer3 = nn.Sequential(SpecialBlock(128, 256, 2), CommonBlock(256, 256))
        self.layer4 = nn.Sequential(SpecialBlock(256, 512, 2), CommonBlock(512, 512))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.prepare(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        x = self.fc(x)
        return x
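
# Feature-map trace for a 224x224 input (derived from the layers above, added for clarity):
#   input    (N,   3, 224, 224)
#   prepare  (N,  64,  56,  56)
#   layer1   (N,  64,  56,  56)
#   layer2   (N, 128,  28,  28)
#   layer3   (N, 256,  14,  14)
#   layer4   (N, 512,   7,   7)
#   avgpool  (N, 512,   1,   1)  -> flatten -> fc -> (N, num_classes)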


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='ResNet-18 implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    args = parser.parse_args()

    model = ResNet18(args.input_channels, args.output_num)
    tensor = torch.rand(1, args.input_channels, 224, 224)
    pred = model(tensor)
    print(model)
    print("Predictions shape:", pred.shape)
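
    # Optional sanity check (illustrative sketch, not part of the original file):
    # run a single optimization step on random labels to confirm gradients flow
    # end to end. The loss value itself is meaningless here.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    labels = torch.randint(0, args.output_num, (tensor.size(0),))
    loss = criterion(model(tensor), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("Dummy training-step loss:", loss.item())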