LeNet.py (1.4 KB)
  1. import torch
  2. import torch.nn as nn
  3. class LeNet(nn.Module):
  4. def __init__(self, input_channels, output_num, input_size):
  5. super(LeNet, self).__init__()
  6. self.features = nn.Sequential(
  7. nn.Conv2d(input_channels, 16, 5),
  8. nn.MaxPool2d(2, 2),
  9. nn.Conv2d(16, 32, 5),
  10. nn.MaxPool2d(2, 2)
  11. )
  12. self.input_size = input_size
  13. self.input_channels = input_channels
  14. self._init_classifier(output_num)
  15. def _init_classifier(self, output_num):
  16. with torch.no_grad():
  17. # Forward a dummy input through the feature extractor part of the network
  18. dummy_input = torch.zeros(1, self.input_channels, self.input_size, self.input_size)
  19. features_size = self.features(dummy_input).numel()
  20. self.classifier = nn.Sequential(
  21. nn.Linear(features_size, 120),
  22. nn.Linear(120, 84),
  23. nn.Linear(84, output_num)
  24. )
  25. def forward(self, x):
  26. x = self.features(x)
  27. x = x.reshape(x.size(0), -1)
  28. x = self.classifier(x)
  29. return x
  30. def get_encode_layers(self):
  31. """
  32. 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
  33. """
  34. conv_list = []
  35. for module in self.modules():
  36. if isinstance(module, nn.Conv2d):
  37. conv_list.append(module)
  38. return conv_list