# AlexNet.py (3.0 KB)
  1. import mindspore
  2. import mindspore.nn as nn
  3. import mindspore.ops.operations as P
  4. class AlexNet(nn.Cell):
  5. def __init__(self, input_channels, output_num, input_size):
  6. super().__init__()
  7. self.features = nn.SequentialCell([
  8. nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, pad_mode='pad', padding=1,
  9. has_bias=True),
  10. nn.BatchNorm2d(num_features=64, momentum=0.9), # 批量归一化层
  11. nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid'),
  12. nn.ReLU(),
  13. nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, pad_mode='pad', padding=1, has_bias=True),
  14. nn.BatchNorm2d(num_features=192, momentum=0.9), # 批量归一化层
  15. nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid'),
  16. nn.ReLU(),
  17. nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, pad_mode='pad', padding=1, has_bias=True),
  18. nn.BatchNorm2d(num_features=384, momentum=0.9), # 批量归一化层
  19. nn.ReLU(),
  20. nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, pad_mode='pad', padding=1, has_bias=True),
  21. nn.BatchNorm2d(num_features=256, momentum=0.9), # 批量归一化层
  22. nn.ReLU(),
  23. nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, pad_mode='pad', padding=1, has_bias=True),
  24. nn.BatchNorm2d(num_features=256, momentum=0.9), # 批量归一化层
  25. nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid'),
  26. nn.ReLU(),
  27. ])
  28. self.input_size = input_size
  29. self._init_classifier(output_num)
  30. def _init_classifier(self, output_num):
  31. # Forward a dummy input through the feature extractor part of the network
  32. dummy_input = mindspore.ops.zeros((1, 3, self.input_size, self.input_size))
  33. features_size = self.features(dummy_input).numel()
  34. self.classifier = nn.SequentialCell([
  35. nn.Dropout(p=0.5),
  36. nn.Dense(in_channels=features_size, out_channels=1000),
  37. nn.ReLU(),
  38. nn.Dropout(p=0.5),
  39. nn.Dense(in_channels=1000, out_channels=256),
  40. nn.ReLU(),
  41. nn.Dense(in_channels=256, out_channels=output_num)
  42. ])
  43. def construct(self, x):
  44. x = self.features(x)
  45. x = P.Reshape()(x, (P.Shape()(x)[0], -1,))
  46. x = self.classifier(x)
  47. return x
  48. if __name__ == '__main__':
  49. import argparse
  50. parser = argparse.ArgumentParser(description='AlexNet Implementation')
  51. parser.add_argument('--input_channels', default=3, type=int)
  52. parser.add_argument('--output_num', default=10, type=int)
  53. parser.add_argument('--input_size', default=32, type=int)
  54. args = parser.parse_args()
  55. model = AlexNet(args.input_channels, args.output_num, args.input_size)
  56. tensor = mindspore.ops.rand((1, args.input_channels, args.input_size, args.input_size))
  57. pred = model(tensor)
  58. print(model)
  59. print("Predictions shape:", pred.shape)