import torch
import torch.nn as nn


class LeNet(nn.Module):
    """LeNet-style CNN: two conv/max-pool stages followed by three linear layers.

    The number of flattened features fed to the classifier is computed at
    construction time by forwarding a zero tensor through ``features``, so the
    first ``nn.Linear`` is sized for the configured input resolution.

    NOTE(review): there are no activation functions between layers, so the
    three ``nn.Linear`` layers compose to a single affine map. Left unchanged
    so that module indices (and therefore ``state_dict`` keys of existing
    checkpoints) stay stable -- confirm this is intended.
    """

    def __init__(self, input_channels, output_num, input_size):
        """
        Args:
            input_channels: number of channels of the input images.
            output_num: number of outputs of the final linear layer.
            input_size: spatial size of the input; a scalar for square
                (H == W) inputs, or an ``(height, width)`` pair.
        """
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2),
        )
        self.input_size = input_size
        self.input_channels = input_channels
        self._init_classifier(output_num)

    @staticmethod
    def _spatial_dims(input_size):
        """Return ``(height, width)`` from a scalar size or a 2-element pair."""
        try:
            height, width = input_size
        except TypeError:
            # Not iterable (int, numpy integer, ...): a square input, as in
            # the original single-int API.
            return input_size, input_size
        return height, width

    def _init_classifier(self, output_num):
        """Build ``self.classifier`` sized to the flattened output of ``features``."""
        height, width = self._spatial_dims(self.input_size)
        # Forward a dummy input through the feature extractor to find out how
        # many values reach the classifier; no_grad avoids building a graph.
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.input_channels, height, width)
            features_size = self.features(dummy_input).numel()
        self.classifier = nn.Sequential(
            nn.Linear(features_size, 120),
            nn.Linear(120, 84),
            nn.Linear(84, output_num),
        )

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        x = self.features(x)
        # Keep the batch dimension, flatten everything else.
        x = x.reshape(x.size(0), -1)
        x = self.classifier(x)
        return x

    def get_encode_layers(self):
        """Return the layers used for white-box model watermark embedding.

        Each model chooses suitable convolutional layers according to its
        complexity; for LeNet this is every ``nn.Conv2d`` module, in the order
        yielded by ``self.modules()``.
        """
        return [module for module in self.modules() if isinstance(module, nn.Conv2d)]