# badnet.py (1.8 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class BadNet(nn.Module):
  5. def __init__(self, input_channels, output_num):
  6. super().__init__()
  7. self.conv1 = nn.Sequential(
  8. nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
  9. nn.BatchNorm2d(16), # 添加批量归一化
  10. nn.ReLU(),
  11. nn.AvgPool2d(kernel_size=2, stride=2)
  12. )
  13. self.conv2 = nn.Sequential(
  14. nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
  15. nn.BatchNorm2d(32), # 添加批量归一化
  16. nn.ReLU(),
  17. nn.AvgPool2d(kernel_size=2, stride=2)
  18. )
  19. # 计算全连接层的输入特征数
  20. fc1_input_features = 800 if input_channels == 3 else 512
  21. self.fc1 = nn.Sequential(
  22. nn.Linear(in_features=fc1_input_features, out_features=512),
  23. nn.ReLU()
  24. )
  25. self.fc2 = nn.Linear(in_features=512, out_features=output_num) # 移除 Softmax
  26. self.dropout = nn.Dropout(p=.5)
  27. def forward(self, x):
  28. x = self.conv1(x)
  29. x = self.conv2(x)
  30. x = x.view(x.size(0), -1) # 展平
  31. x = self.fc1(x)
  32. x = self.dropout(x) # 应用 dropout
  33. x = self.fc2(x)
  34. return x
  35. if __name__ == '__main__':
  36. import argparse
  37. parser = argparse.ArgumentParser(description='Badnet Implementation')
  38. parser.add_argument('--input_channels', default=3, type=int)
  39. parser.add_argument('--output_num', default=10, type=int)
  40. args = parser.parse_args()
  41. model = BadNet(args.input_channels, args.output_num)
  42. tensor = torch.rand(1, args.input_channels, 32, 32)
  43. pred = model(tensor)
  44. print(model)
  45. print("Predictions shape:", pred.shape)