liyan пре 11 месеци
родитељ
комит
7db67fc5a6
1 измењених фајлова са 76 додато и 0 уклоњено
  1. 76 0
      tests/model/Alexnet.py

+ 76 - 0
tests/model/Alexnet.py

@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+
+
+class Alexnet(nn.Module):
+    def __init__(self, input_channels, output_num, input_size):
+        super().__init__()
+
+        self.features = nn.Sequential(
+            nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1),
+            nn.BatchNorm2d(64),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
+            nn.BatchNorm2d(192),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, padding=1),
+            nn.BatchNorm2d(384),  # 批量归一化层
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),  # 批量归一化层
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
+            nn.BatchNorm2d(256),  # 批量归一化层
+            nn.MaxPool2d(kernel_size=2),
+            nn.ReLU(inplace=True),
+        )
+
+        self.input_size = input_size
+        self._init_classifier(output_num)
+
+    def _init_classifier(self, output_num):
+        with torch.no_grad():
+            # Forward a dummy input through the feature extractor part of the network
+            dummy_input = torch.zeros(1, 3, self.input_size, self.input_size)
+            features_size = self.features(dummy_input).numel()
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(0.5),
+            nn.Linear(features_size, 1000),
+            nn.ReLU(inplace=True),
+
+            nn.Dropout(0.5),
+            nn.Linear(1000, 256),
+            nn.ReLU(inplace=True),
+
+            nn.Linear(256, output_num)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.reshape(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='AlexNet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    parser.add_argument('--input_size', default=32, type=int)
+    args = parser.parse_args()
+
+    model = Alexnet(args.input_channels, args.output_num, args.input_size)
+    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
+    pred = model(tensor)
+
+    print(model)
+    print("Predictions shape:", pred.shape)