123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class BadNet(nn.Module):
    """Small CNN classifier: two conv-BN-ReLU-avgpool stages followed by two linear layers.

    The network returns raw logits (no Softmax is applied in ``forward``), so it should be
    paired with a loss that works on logits, such as ``nn.CrossEntropyLoss``.

    Args:
        input_channels: Number of channels of the input images.
        output_num: Number of output classes (length of the logit vector).
        input_size: Spatial size of the input images, given either as an int (square
            images) or as a ``(height, width)`` pair. When ``None`` (the default), the size
            is inferred from ``input_channels`` exactly as the original implementation did:
            32x32 for 3-channel input and 28x28 otherwise.

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv(5x5) + avgpool(2x2)
            stages.
    """

    def __init__(self, input_channels, output_num, input_size=None):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
            nn.BatchNorm2d(16),  # batch normalization
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
            nn.BatchNorm2d(32),  # batch normalization
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Number of input features for the first fully connected layer. It is derived from
        # the real spatial size instead of hard-coding 800/512, so the model works for any
        # resolution large enough to survive both stages. conv2 emits 32 channels.
        fc1_input_features = 32 * self._feature_map_area(input_channels, input_size)
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=fc1_input_features, out_features=512),
            nn.ReLU(),
        )
        # Final layer outputs logits; Softmax was intentionally removed.
        self.fc2 = nn.Linear(in_features=512, out_features=output_num)
        self.dropout = nn.Dropout(p=.5)

    @staticmethod
    def _feature_map_area(input_channels, input_size):
        """Return height * width of the conv2 output for the given input spatial size."""
        if input_size is None:
            # Reproduce the original implicit assumption: 3-channel inputs are 32x32 and all
            # others are 28x28. This yields 800 and 512 features respectively.
            input_size = 32 if input_channels == 3 else 28
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        def after_stage(size):
            # Conv2d(kernel_size=5, stride=1, padding=0) shrinks each side by 4, then
            # AvgPool2d(kernel_size=2, stride=2) halves it (rounding down).
            return (size - 4) // 2

        out_h = after_stage(after_stage(height))
        out_w = after_stage(after_stage(width))
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two conv(5x5)+avgpool(2x2) stages"
            )
        return out_h * out_w

    def forward(self, x):
        """Map a batch of images of shape (N, C, H, W) to logits of shape (N, output_num)."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten everything except the batch dimension. Unlike .view, this also works for
        # non-contiguous tensors and for an empty batch.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.dropout(x)  # active only in training mode
        return self.fc2(x)
if __name__ == '__main__':
    # Smoke test: build a BadNet from CLI arguments and run one random input through it.
    import argparse

    parser = argparse.ArgumentParser(description='Badnet Implementation')
    parser.add_argument('--input_channels', default=3, type=int)
    parser.add_argument('--output_num', default=10, type=int)
    args = parser.parse_args()

    model = BadNet(args.input_channels, args.output_num)
    # By default BadNet sizes fc1 for 32x32 input when there are 3 channels and 28x28
    # otherwise. A fixed 32x32 dummy would hit a shape mismatch for 1-channel models.
    side = 32 if args.input_channels == 3 else 28
    tensor = torch.rand(1, args.input_channels, side, side)

    # Inference mode: BatchNorm uses (and does not update) its running stats, dropout is
    # disabled, and no autograd graph is built for this shape check.
    model.eval()
    with torch.no_grad():
        pred = model(tensor)

    print(model)
    print("Predictions shape:", pred.shape)
|