# mobilenetv2.py
'''MobileNetV2 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
  8. class Block(nn.Module):
  9. '''expand + depthwise + pointwise'''
  10. def __init__(self, in_planes, out_planes, expansion, stride):
  11. super(Block, self).__init__()
  12. self.stride = stride
  13. planes = expansion * in_planes
  14. self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
  15. self.bn1 = nn.BatchNorm2d(planes)
  16. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
  17. self.bn2 = nn.BatchNorm2d(planes)
  18. self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
  19. self.bn3 = nn.BatchNorm2d(out_planes)
  20. self.shortcut = nn.Sequential()
  21. if stride == 1 and in_planes != out_planes:
  22. self.shortcut = nn.Sequential(
  23. nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
  24. nn.BatchNorm2d(out_planes),
  25. )
  26. def forward(self, x):
  27. out = F.relu(self.bn1(self.conv1(x)))
  28. out = F.relu(self.bn2(self.conv2(out)))
  29. out = self.bn3(self.conv3(out))
  30. out = out + self.shortcut(x) if self.stride==1 else out
  31. return out
  32. class MobileNetV2(nn.Module):
  33. # (expansion, out_planes, num_blocks, stride)
  34. cfg = [(1, 16, 1, 1),
  35. (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
  36. (6, 32, 3, 2),
  37. (6, 64, 4, 2),
  38. (6, 96, 3, 1),
  39. (6, 160, 3, 2),
  40. (6, 320, 1, 1)]
  41. def __init__(self, input_channels, output_num):
  42. super(MobileNetV2, self).__init__()
  43. # NOTE: change conv1 stride 2 -> 1 for CIFAR10
  44. self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
  45. self.bn1 = nn.BatchNorm2d(32)
  46. self.layers = self._make_layers(in_planes=32)
  47. self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
  48. self.bn2 = nn.BatchNorm2d(1280)
  49. self.linear = nn.Linear(1280, output_num)
  50. def _make_layers(self, in_planes):
  51. layers = []
  52. for expansion, out_planes, num_blocks, stride in self.cfg:
  53. strides = [stride] + [1]*(num_blocks-1)
  54. for stride in strides:
  55. layers.append(Block(in_planes, out_planes, expansion, stride))
  56. in_planes = out_planes
  57. return nn.Sequential(*layers)
  58. def forward(self, x):
  59. out = F.relu(self.bn1(self.conv1(x)))
  60. out = self.layers(out)
  61. out = F.relu(self.bn2(self.conv2(out)))
  62. # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
  63. out = F.avg_pool2d(out, 4)
  64. out = out.view(out.size(0), -1)
  65. out = self.linear(out)
  66. return out
  67. if __name__ == '__main__':
  68. import argparse
  69. parser = argparse.ArgumentParser(description='MobileNetV2 Implementation')
  70. parser.add_argument('--input_channels', default=3, type=int)
  71. parser.add_argument('--output_num', default=10, type=int)
  72. # parser.add_argument('--input_size', default=32, type=int)
  73. args = parser.parse_args()
  74. model = MobileNetV2(args.input_channels, args.output_num)
  75. tensor = torch.rand(1, args.input_channels, 32, 32)
  76. pred = model(tensor)
  77. print(model)
  78. print("Predictions shape:", pred.shape)