12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import torch
- import torch.nn as nn
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/max-pool stages followed by three linear layers.

    The input width of the first linear layer is not hard-coded; it is measured
    once at construction time by pushing a zero tensor through ``features``.

    NOTE(review): neither ``features`` nor ``classifier`` contains an activation
    function. The three stacked ``nn.Linear`` layers therefore compose to a
    single affine map, and the only non-linearity in the network is max
    pooling. Inserting activation modules would shift the ``nn.Sequential``
    indices (and hence the ``state_dict`` keys) and change the outputs of any
    existing checkpoint, so the architecture is left as-is — confirm whether
    this is intentional.
    """

    def __init__(self, input_channels, output_num, input_size):
        """Build the network.

        Args:
            input_channels: number of channels of the input images.
            output_num: number of outputs of the final linear layer.
            input_size: spatial size of the input images. Either a single value
                for square ``input_size x input_size`` images (the original
                behavior) or a ``(height, width)`` tuple/list. With two
                unpadded 5x5 convolutions each followed by a 2x2/stride-2 max
                pool, every spatial dimension must be at least 16 or the
                feature maps vanish and torch raises while sizing the
                classifier.
        """
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2),
        )
        # Stored exactly as given (scalar or pair) so code that reads these
        # attributes sees the same values as before.
        self.input_size = input_size
        self.input_channels = input_channels
        self._init_classifier(output_num)

    def _init_classifier(self, output_num):
        """Create ``self.classifier`` sized to the flattened ``features`` output.

        Args:
            output_num: number of outputs of the final linear layer.
        """
        if isinstance(self.input_size, (tuple, list)):
            height, width = self.input_size
        else:
            # Any scalar accepted by torch.zeros keeps working as before.
            height = width = self.input_size
        with torch.no_grad():
            # One zero sample is enough: only the output element count matters,
            # and no autograd graph is needed for this shape probe.
            dummy_input = torch.zeros(1, self.input_channels, height, width)
            features_size = self.features(dummy_input).numel()
        self.classifier = nn.Sequential(
            nn.Linear(features_size, 120),
            nn.Linear(120, 84),
            nn.Linear(84, output_num),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the classifier output.

        Assumes a batched input laid out like the probe tensor used in
        ``_init_classifier`` (batch, channels, height, width) — TODO confirm
        for callers passing other layouts.
        """
        x = self.features(x)
        # Flatten everything except the batch dimension; unlike
        # reshape(N, -1) this also works for an empty batch.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def get_encode_layers(self):
        """Return the layers used to embed a white-box model watermark.

        Each model selects suitable convolutional layers according to its
        complexity; this one returns every ``nn.Conv2d`` it contains, in the
        order yielded by ``self.modules()``.
        """
        return [module for module in self.modules() if isinstance(module, nn.Conv2d)]
|