|
@@ -61,10 +61,19 @@ class ResNet18(nn.Module):
|
|
|
x = self.layer3(x)
|
|
|
x = self.layer4(x)
|
|
|
x = self.avgpool(x)
|
|
|
- x = x.view(x.size(0), -1)
|
|
|
+ x = x.reshape(x.size(0), -1)
|
|
|
x = self.fc(x)
|
|
|
return x
|
|
|
|
|
|
+ def get_encode_layers(self):
|
|
|
+ """
|
|
|
+ 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
|
|
|
+ """
|
|
|
+ conv_list = []
|
|
|
+ for module in self.modules():
|
|
|
+ if isinstance(module, nn.Conv2d) and module.out_channels > 100:
|
|
|
+ conv_list.append(module)
|
|
|
+ return conv_list[2:4]
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
import argparse
|