@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+
+
+class LeNet(nn.Module):
+    def __init__(self, input_channels, output_num, input_size):
+        super(LeNet, self).__init__()
+
+        # Convolutional feature extractor: two conv/ReLU/max-pool blocks
+        self.features = nn.Sequential(
+            nn.Conv2d(input_channels, 16, 5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(16, 32, 5),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.input_size = input_size
+        self.input_channels = input_channels
+        self._init_classifier(output_num)
+
+    def _init_classifier(self, output_num):
+        with torch.no_grad():
+            # Forward a dummy input through the feature extractor to infer the flattened feature size
+            dummy_input = torch.zeros(1, self.input_channels, self.input_size, self.input_size)
+            features_size = self.features(dummy_input).numel()
+
+        self.classifier = nn.Sequential(
+            nn.Linear(features_size, 120),
+            nn.ReLU(),
+            nn.Linear(120, 84),
+            nn.ReLU(),
+            nn.Linear(84, output_num)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.reshape(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+    def get_encode_layers(self):
+        """
+        Return the convolutional layers used for white-box watermark embedding; each model selects suitable convolutional layers according to its complexity.
+        """
+        conv_list = []
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                conv_list.append(module)
+        return conv_list
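
For reviewers, a minimal usage sketch of the class added above. The import path, the 1-channel 32x32 input shape (e.g. MNIST padded to 32x32), the class count, and the batch size are assumptions for illustration, not part of this diff:

```python
import torch

from lenet import LeNet  # hypothetical module path; adjust to wherever this file lands

# Assumed example: grayscale 32x32 input, 10 output classes
model = LeNet(input_channels=1, output_num=10, input_size=32)

# The first Linear layer of the classifier is sized via the dummy forward pass in
# _init_classifier: 32x32 -> conv5 -> 28x28 -> pool -> 14x14 -> conv5 -> 10x10 -> pool -> 5x5,
# i.e. 32 channels * 5 * 5 = 800 input features.
x = torch.randn(4, 1, 32, 32)
logits = model(x)
print(logits.shape)  # torch.Size([4, 10])

# Convolutional layers exposed for white-box watermark embedding
print(model.get_encode_layers())  # [Conv2d(1, 16, ...), Conv2d(16, 32, ...)]
```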