Prechádzať zdrojové kódy

VGG19集成白盒水印

liyan 1 rok pred
rodič
commit
bf0ccfaab8
1 zmenil súbory, kde vykonal 11 pridanie a 1 odobranie
  1. 11 1
      model/VGG19.py

+ 11 - 1
model/VGG19.py

@@ -32,10 +32,20 @@ class VGG(nn.Module):
         
     def forward(self, x):
         x = self.features(x)
-        x = x.view(x.size(0), -1)
+        x = x.reshape(x.size(0), -1)
         x = self.fc(x)
         return x
 
+    def get_encode_layers(self):
+        """
+        获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
+        """
+        conv_list = []
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d) and module.out_channels > 100:
+                conv_list.append(module)
+        return conv_list[1:3]
+
 def VGG11():
     return VGG('VGG11')