lenet.py 884 B

1234567891011121314151617181920212223242526
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class LeNet(nn.Module):
  5. def __init__(self):
  6. super(LeNet, self).__init__()
  7. self.conv1 = nn.Conv2d(3, 16, 5)
  8. self.pool1 = nn.MaxPool2d(2, 2)
  9. self.conv2 = nn.Conv2d(16, 32, 5)
  10. self.pool2 = nn.MaxPool2d(2, 2)
  11. self.fc1 = nn.Linear(32 * 5 * 5, 120)
  12. self.fc2 = nn.Linear(120, 84)
  13. self.fc3 = nn.Linear(84, 10)
  14. def forward(self, x):
  15. x = F.relu(self.conv1(x)) # input(3,32,32) output(16,28,28)
  16. x = self.pool1(x) # output(16,14,14)
  17. x = F.relu(self.conv2(x)) # output(32,10.10)
  18. x = self.pool2(x) # output(32,5,5)
  19. x = x.view(-1, 32 * 5 * 5) # output(5*5*32)
  20. x = F.relu(self.fc1(x)) # output(120)
  21. x = F.relu(self.fc2(x)) # output(84)
  22. x = self.fc3(x) # output(10)
  23. return x