resnet.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class CommonBlock(nn.Module):
  5. """ Standard residual block without downsampling. """
  6. def __init__(self, in_channel, out_channel, stride=1):
  7. super(CommonBlock, self).__init__()
  8. self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
  9. self.bn1 = nn.BatchNorm2d(out_channel)
  10. self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
  11. self.bn2 = nn.BatchNorm2d(out_channel)
  12. def forward(self, x):
  13. identity = x
  14. x = F.relu(self.bn1(self.conv1(x)), inplace=True)
  15. x = self.bn2(self.conv2(x))
  16. x += identity
  17. return F.relu(x, inplace=True)
  18. class SpecialBlock(nn.Module):
  19. """ Residual block with downsampling and channel size increase. """
  20. def __init__(self, in_channel, out_channel, stride):
  21. super(SpecialBlock, self).__init__()
  22. self.change_channel = nn.Sequential(
  23. nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
  24. nn.BatchNorm2d(out_channel)
  25. )
  26. self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
  27. self.bn1 = nn.BatchNorm2d(out_channel)
  28. self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
  29. self.bn2 = nn.BatchNorm2d(out_channel)
  30. def forward(self, x):
  31. identity = self.change_channel(x)
  32. x = F.relu(self.bn1(self.conv1(x)), inplace=True)
  33. x = self.bn2(self.conv2(x))
  34. x += identity
  35. return F.relu(x, inplace=True)
  36. class ResNet18(nn.Module):
  37. def __init__(self, input_channels, num_classes=10):
  38. super(ResNet18, self).__init__()
  39. self.prepare = nn.Sequential(
  40. nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
  41. nn.BatchNorm2d(64),
  42. nn.ReLU(inplace=True),
  43. nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  44. )
  45. self.layer1 = nn.Sequential(CommonBlock(64, 64), CommonBlock(64, 64))
  46. self.layer2 = nn.Sequential(SpecialBlock(64, 128, 2), CommonBlock(128, 128))
  47. self.layer3 = nn.Sequential(SpecialBlock(128, 256, 2), CommonBlock(256, 256))
  48. self.layer4 = nn.Sequential(SpecialBlock(256, 512, 2), CommonBlock(512, 512))
  49. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  50. self.fc = nn.Linear(512, num_classes)
  51. def forward(self, x):
  52. x = self.prepare(x)
  53. x = self.layer1(x)
  54. x = self.layer2(x)
  55. x = self.layer3(x)
  56. x = self.layer4(x)
  57. x = self.avgpool(x)
  58. x = x.reshape(x.size(0), -1)
  59. x = self.fc(x)
  60. return x
  61. def get_encode_layers(self):
  62. """
  63. 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
  64. """
  65. conv_list = []
  66. for module in self.modules():
  67. if isinstance(module, nn.Conv2d) and module.out_channels > 100:
  68. conv_list.append(module)
  69. return conv_list[2:4]
  70. if __name__ == '__main__':
  71. import argparse
  72. parser = argparse.ArgumentParser(description='Resnet Implementation')
  73. parser.add_argument('--input_channels', default=3, type=int)
  74. parser.add_argument('--output_num', default=10, type=int)
  75. args = parser.parse_args()
  76. model = ResNet18(args.input_channels, args.output_num)
  77. tensor = torch.rand(1, args.input_channels, 224, 224)
  78. pred = model(tensor)
  79. print(model)
  80. print("Predictions shape:", pred.shape)