# VGG19.py (2.5 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. _cfg = {
  5. 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  6. 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  7. 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  8. 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  9. }
  10. def _make_layers(cfg, input_size):
  11. layers = []
  12. in_channels = 3
  13. for layer_cfg in cfg:
  14. if layer_cfg == 'M':
  15. layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
  16. input_size = input_size // 2
  17. else:
  18. layers.append(nn.Conv2d(in_channels=in_channels, out_channels=layer_cfg, kernel_size=3, stride=1, padding=1))
  19. layers.append(nn.BatchNorm2d(num_features=layer_cfg))
  20. layers.append(nn.ReLU(inplace=True))
  21. in_channels = layer_cfg
  22. return nn.Sequential(*layers), input_size
  23. class VGG(nn.Module):
  24. def __init__(self, name, input_size=32, num_classes=10):
  25. super(VGG, self).__init__()
  26. cfg = _cfg[name]
  27. self.features, final_size = _make_layers(cfg, input_size)
  28. self.fc = nn.Linear(512 * final_size * final_size, num_classes)
  29. def forward(self, x):
  30. x = self.features(x)
  31. x = x.reshape(x.size(0), -1)
  32. x = self.fc(x)
  33. return x
  34. def get_encode_layers(self):
  35. """
  36. 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
  37. """
  38. conv_list = []
  39. for module in self.modules():
  40. if isinstance(module, nn.Conv2d) and module.out_channels > 100:
  41. conv_list.append(module)
  42. return conv_list[1:3]
  43. def VGG11():
  44. return VGG('VGG11')
  45. def VGG13():
  46. return VGG('VGG13')
  47. def VGG16():
  48. return VGG('VGG16')
  49. def VGG19():
  50. return VGG('VGG19')
  51. if __name__ == '__main__':
  52. import argparse
  53. parser = argparse.ArgumentParser(description='VGG Model Test')
  54. parser.add_argument('--input_channels', default=3, type=int)
  55. parser.add_argument('--output_num', default=10, type=int)
  56. parser.add_argument('--input_size', default=32, type=int)
  57. args = parser.parse_args()
  58. model = VGG19() # Changed to use VGG19
  59. tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
  60. pred = model(tensor)
  61. print(model)
  62. print("Predictions shape:", pred.shape)