Alexnet.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Alexnet(nn.Module):
  5. def __init__(self, input_channels, output_num, input_size):
  6. super().__init__()
  7. self.features = nn.Sequential(
  8. nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1),
  9. nn.BatchNorm2d(64), # 批量归一化层
  10. nn.MaxPool2d(kernel_size=2),
  11. nn.ReLU(inplace=True),
  12. nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
  13. nn.BatchNorm2d(192), # 批量归一化层
  14. nn.MaxPool2d(kernel_size=2),
  15. nn.ReLU(inplace=True),
  16. nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
  17. nn.BatchNorm2d(384), # 批量归一化层
  18. nn.ReLU(inplace=True),
  19. nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
  20. nn.BatchNorm2d(256), # 批量归一化层
  21. nn.ReLU(inplace=True),
  22. nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
  23. nn.BatchNorm2d(256), # 批量归一化层
  24. nn.MaxPool2d(kernel_size=2),
  25. nn.ReLU(inplace=True),
  26. )
  27. self.input_size = input_size
  28. self._init_classifier(output_num)
  29. def _init_classifier(self, output_num):
  30. with torch.no_grad():
  31. # Forward a dummy input through the feature extractor part of the network
  32. dummy_input = torch.zeros(1, 3, self.input_size, self.input_size)
  33. features_size = self.features(dummy_input).numel()
  34. self.classifier = nn.Sequential(
  35. nn.Dropout(0.5),
  36. nn.Linear(features_size, 1000),
  37. nn.ReLU(inplace=True),
  38. nn.Dropout(0.5),
  39. nn.Linear(1000, 256),
  40. nn.ReLU(inplace=True),
  41. nn.Linear(256, output_num)
  42. )
  43. def forward(self, x):
  44. x = self.features(x)
  45. x = x.reshape(x.size(0), -1)
  46. x = self.classifier(x)
  47. return x
  48. def get_encode_layers(self):
  49. """
  50. 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
  51. """
  52. conv_list = []
  53. for module in self.modules():
  54. if isinstance(module, nn.Conv2d):
  55. conv_list.append(module)
  56. return conv_list[0:2]
  57. if __name__ == '__main__':
  58. import argparse
  59. parser = argparse.ArgumentParser(description='AlexNet Implementation')
  60. parser.add_argument('--input_channels', default=3, type=int)
  61. parser.add_argument('--output_num', default=10, type=int)
  62. parser.add_argument('--input_size', default=32, type=int)
  63. args = parser.parse_args()
  64. model = Alexnet(args.input_channels, args.output_num, args.input_size)
  65. tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
  66. pred = model(tensor)
  67. print(model)
  68. print("Predictions shape:", pred.shape)