Browse Source

实现LeNet白盒水印集成

liyan 1 year ago
parent
commit
f5af0ecbbf
2 changed files with 54 additions and 2 deletions
  1. 8 2
      block/model_get.py
  2. 46 0
      model/LeNet.py

+ 8 - 2
block/model_get.py

@@ -10,7 +10,8 @@ choice_dict = {
     'mobilenetv2': 'model_prepare(args).mobilenetv2()',
     'resnet': 'model_prepare(args).resnet()',
     'VGG19': 'model_prepare(args).VGG19()',
-    'efficientnet': 'model_prepare(args).EfficientNetV2_S()'
+    'efficientnet': 'model_prepare(args).EfficientNetV2_S()',
+    'LeNet': 'model_prepare(args).LeNet()'
 }
 
 
@@ -157,4 +158,9 @@ class model_prepare:
     def EfficientNetV2_S(self):
         from model.efficientnet import EfficientNetV2_S
         model = EfficientNetV2_S(self.args.input_channels, self.args.output_num)
-        return model
+        return model
+
+    def LeNet(self):
+        from model.LeNet import LeNet
+        model = LeNet(self.args.input_channels, self.args.output_num, self.args.input_size)
+        return model

+ 46 - 0
model/LeNet.py

@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+
+
+class LeNet(nn.Module):
+    def __init__(self, input_channels, output_num, input_size):
+        super(LeNet, self).__init__()
+
+        self.features = nn.Sequential(
+            nn.Conv2d(input_channels, 16, 5),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(16, 32, 5),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.input_size = input_size
+        self.input_channels = input_channels
+        self._init_classifier(output_num)
+
+    def _init_classifier(self, output_num):
+        with torch.no_grad():
+            # Forward a dummy input through the feature extractor part of the network
+            dummy_input = torch.zeros(1, self.input_channels, self.input_size, self.input_size)
+            features_size = self.features(dummy_input).numel()
+
+        self.classifier = nn.Sequential(
+            nn.Linear(features_size, 120),
+            nn.Linear(120, 84),
+            nn.Linear(84, output_num)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.reshape(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+    def get_encode_layers(self):
+        """
+        获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
+        """
+        conv_list = []
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                conv_list.append(module)
+        return conv_list