import torch
import torch.nn as nn
import torch.nn.functional as F

# Layer configurations: integers are conv output channels, 'M' marks a 2x2 max-pooling stage.
_cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

def _make_layers(cfg, input_size):
    """Build the convolutional feature extractor and return it together with
    the spatial size of its output feature map."""
    layers = []
    in_channels = 3
    for layer_cfg in cfg:
        if layer_cfg == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            input_size = input_size // 2
        else:
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg,
                                    kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(num_features=layer_cfg))
            layers.append(nn.ReLU(inplace=True))
            in_channels = layer_cfg
    return nn.Sequential(*layers), input_size

class VGG(nn.Module):
    def __init__(self, name, input_size=32, num_classes=10):
        super(VGG, self).__init__()
        cfg = _cfg[name]
        self.features, final_size = _make_layers(cfg, input_size)
        self.fc = nn.Linear(512 * final_size * final_size, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_encode_layers(self):
        """Return the convolutional layers used to carry the white-box model
        watermark; each model picks suitable conv layers according to its
        complexity."""
        conv_list = []
        for module in self.modules():
            if isinstance(module, nn.Conv2d) and module.out_channels > 100:
                conv_list.append(module)
        return conv_list[1:3]
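
# Illustrative sketch (not part of the original file): a common white-box
# watermarking recipe embeds a bit string into the weights of the selected conv
# layers through a fixed random projection plus a BCE penalty added to the
# training loss (in the spirit of Uchida et al.). The helper below is a
# hypothetical example of how the layers returned by get_encode_layers() could
# be used; its name, arguments, and projection scheme are assumptions, not the
# embedding method used elsewhere in this project.
def watermark_regularizer(model, watermark_bits, projections):
    """Return a BCE penalty pushing projected conv weights toward the target bits.

    watermark_bits: float tensor of 0/1 targets, shape (num_layers, n_bits)
    projections:    one fixed random matrix per encode layer,
                    each of shape (n_bits, layer.weight.numel())
    """
    loss = 0.0
    for layer, bits, proj in zip(model.get_encode_layers(), watermark_bits, projections):
        w = layer.weight.reshape(-1)      # flatten the conv kernel weights
        logits = proj @ w                 # project onto n_bits watermark positions
        loss = loss + F.binary_cross_entropy_with_logits(logits, bits)
    return loss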

def VGG11():
    return VGG('VGG11')

def VGG13():
    return VGG('VGG13')

def VGG16():
    return VGG('VGG16')

def VGG19():
    return VGG('VGG19')

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='VGG Model Test')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    parser.add_argument('--input_size', default=32, type=int)
    args = parser.parse_args()

    # Smoke test on VGG19; note that input channels are fixed to 3 by _make_layers.
    model = VGG('VGG19', input_size=args.input_size, num_classes=args.output_num)
    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
    pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)