
Modify the model-fetching code and add the LeNet model code

liyan committed 1 year ago
commit 2f9da790a0
10 changed files with 65 additions and 336 deletions
  1. block/model_get.py  +12 -39
  2. model/Alexnet.py  +0 -1
  3. model/GoogleNet.py  +0 -1
  4. model/LeNet.py  +53 -0
  5. model/VGG19.py  +0 -1
  6. model/badnet.py  +0 -54
  7. model/mobilenetv2.py  +0 -94
  8. model/test.py  +0 -13
  9. model/timm_model.py  +0 -46
  10. model/yolov7_cls.py  +0 -87

+ 12 - 39
block/model_get.py

@@ -2,15 +2,11 @@ import os
 import torch
 
 choice_dict = {
-    'yolov7_cls': 'model_prepare(args).yolov7_cls()',
-    'timm_model': 'model_prepare(args).timm_model()',
+    'LeNet': 'model_prepare(args).LeNet()',
     'Alexnet': 'model_prepare(args).Alexnet()',
-    'badnet': 'model_prepare(args).badnet()',
-    'GoogleNet': 'model_prepare(args).GoogleNet()',
-    'mobilenetv2': 'model_prepare(args).mobilenetv2()',
-    'resnet': 'model_prepare(args).resnet()',
     'VGG19': 'model_prepare(args).VGG19()',
-    'efficientnet': 'model_prepare(args).EfficientNetV2_S()'
+    'GoogleNet': 'model_prepare(args).GoogleNet()',
+    'resnet': 'model_prepare(args).resnet()'
 }
 
 
@@ -22,9 +18,6 @@ def model_get(args):
             model_dict = torch.load(args.prune_weight, map_location='cpu')
             model = model_dict['model']
             model = prune(args, model)
-        elif args.timm:
-            # model = model_prepare(args).timm_model()
-            model = eval(choice_dict['timm_model'])
         else:
             model = eval(choice_dict[args.model])
         model_dict = {}
@@ -114,47 +107,27 @@ class model_prepare:
     def __init__(self, args):
         self.args = args
 
-    def timm_model(self):
-        from model.timm_model import timm_model
-        model = timm_model(self.args)
+    def LeNet(self):
+        from model.LeNet import LeNet
+        model = LeNet(self.args.input_channels, self.args.output_num, self.args.input_size)
         return model
 
-    def yolov7_cls(self):
-        from model.yolov7_cls import yolov7_cls
-        model = yolov7_cls(self.args)
-        return model
-    
     def Alexnet(self):
         from model.Alexnet import Alexnet
         model = Alexnet(self.args.input_channels, self.args.output_num, self.args.input_size)
         return model
-    
-    def badnet(self):
-        from model.badnet import BadNet
-        model = BadNet(self.args.input_channels, self.args.output_num)
+
+    def VGG19(self):
+        from model.VGG19 import VGG19
+        model = VGG19()
         return model
-    
+
     def GoogleNet(self):
         from model.GoogleNet import GoogLeNet
         model = GoogLeNet(self.args.input_channels, self.args.output_num)
         return model
-    
-    def mobilenetv2(self):
-        from model.mobilenetv2 import MobileNetV2
-        model = MobileNetV2(self.args.input_channels, self.args.output_num)
-        return model
-    
+
     def resnet(self):
         from model.resnet import ResNet18
         model = ResNet18(self.args.input_channels, self.args.output_num)
         return model
-    
-    def VGG19(self):
-        from model.VGG19 import VGG19
-        model = VGG19()
-        return model
-    
-    def EfficientNetV2_S(self):
-        from model.efficientnet import EfficientNetV2_S
-        model = EfficientNetV2_S(self.args.input_channels, self.args.output_num)
-        return model
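
The trimmed choice_dict keeps the existing string-to-constructor dispatch: model_get() looks up args.model and evaluates the corresponding model_prepare call. Below is a minimal, self-contained sketch of that pattern; _ToyNet and the Namespace fields are illustrative stand-ins, not the repository's actual models or argument parsing.

import torch.nn as nn
from argparse import Namespace


class _ToyNet(nn.Module):
    # Stand-in for the repository's model classes (illustrative only).
    def __init__(self, input_channels, output_num):
        super().__init__()
        self.fc = nn.Linear(input_channels, output_num)

    def forward(self, x):
        return self.fc(x)


class model_prepare:
    # Mirrors the dispatch target referenced by choice_dict above.
    def __init__(self, args):
        self.args = args

    def LeNet(self):
        # The repository would import model.LeNet here and pass
        # args.input_channels, args.output_num, args.input_size.
        return _ToyNet(self.args.input_channels, self.args.output_num)


choice_dict = {'LeNet': 'model_prepare(args).LeNet()'}

args = Namespace(model='LeNet', input_channels=3, output_num=10, input_size=32)
model = eval(choice_dict[args.model])  # same eval-based dispatch as model_get()
print(model)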

+ 0 - 1
model/Alexnet.py

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 class Alexnet(nn.Module):
     def __init__(self, input_channels, output_num, input_size):

+ 0 - 1
model/GoogleNet.py

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 class Inception(nn.Module):
     def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):

+ 53 - 0
model/LeNet.py

@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+
+class LeNet(nn.Module):
+    def __init__(self, input_channels, output_num, input_size):
+        super(LeNet, self).__init__()
+
+        self.features = nn.Sequential(
+            nn.Conv2d(input_channels, 16, 5),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(16, 32, 5),
+            nn.MaxPool2d(2, 2)
+        )
+
+        self.input_size = input_size
+        self.input_channels = input_channels
+        self._init_classifier(output_num)
+
+    def _init_classifier(self, output_num):
+        with torch.no_grad():
+            # Forward a dummy input through the feature extractor part of the network
+            dummy_input = torch.zeros(1, self.input_channels, self.input_size, self.input_size)
+            features_size = self.features(dummy_input).numel()
+
+        self.classifier = nn.Sequential(
+            nn.Linear(features_size, 120),
+            nn.Linear(120, 84),
+            nn.Linear(84, output_num)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.reshape(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='LeNet Implementation')
+    parser.add_argument('--input_channels', default=3, type=int)
+    parser.add_argument('--output_num', default=10, type=int)
+    parser.add_argument('--input_size', default=32, type=int)
+    args = parser.parse_args()
+
+    model = LeNet(args.input_channels, args.output_num, args.input_size)
+    tensor = torch.rand(1, args.input_channels, args.input_size, args.input_size)
+    pred = model(tensor)
+
+    print(model)
+    print("Predictions shape:", pred.shape)

+ 0 - 1
model/VGG19.py

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 _cfg = {
     'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],

+ 0 - 54
model/badnet.py

@@ -1,54 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class BadNet(nn.Module):
-
-    def __init__(self, input_channels, output_num):
-        super().__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
-            nn.BatchNorm2d(16),  # add batch normalization
-            nn.ReLU(),
-            nn.AvgPool2d(kernel_size=2, stride=2)
-        )
-
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
-            nn.BatchNorm2d(32),  # add batch normalization
-            nn.ReLU(),
-            nn.AvgPool2d(kernel_size=2, stride=2)
-        )
-        # compute the number of input features for the fully connected layer
-        fc1_input_features = 800 if input_channels == 3 else 512
-        self.fc1 = nn.Sequential(
-            nn.Linear(in_features=fc1_input_features, out_features=512),
-            nn.ReLU()
-        )
-        self.fc2 = nn.Linear(in_features=512, out_features=output_num)  # Softmax removed
-        self.dropout = nn.Dropout(p=.5)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-
-        x = x.view(x.size(0), -1)  # flatten
-        x = self.fc1(x)
-        x = self.dropout(x)  # apply dropout
-        x = self.fc2(x)
-        return x
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser(description='Badnet Implementation')
-    parser.add_argument('--input_channels', default=3, type=int)
-    parser.add_argument('--output_num', default=10, type=int)
-    args = parser.parse_args()
-    
-    model = BadNet(args.input_channels, args.output_num)
-    tensor = torch.rand(1, args.input_channels, 32, 32)
-    pred = model(tensor)
-    
-    print(model)
-    print("Predictions shape:", pred.shape)

+ 0 - 94
model/mobilenetv2.py

@@ -1,94 +0,0 @@
-'''MobileNetV2 in PyTorch.
-See the paper "Inverted Residuals and Linear Bottlenecks:
-Mobile Networks for Classification, Detection and Segmentation" for more details.
-'''
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Block(nn.Module):
-    '''expand + depthwise + pointwise'''
-    def __init__(self, in_planes, out_planes, expansion, stride):
-        super(Block, self).__init__()
-        self.stride = stride
-
-        planes = expansion * in_planes
-        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn3 = nn.BatchNorm2d(out_planes)
-
-        self.shortcut = nn.Sequential()
-        if stride == 1 and in_planes != out_planes:
-            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
-                nn.BatchNorm2d(out_planes),
-            )
-
-    def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = F.relu(self.bn2(self.conv2(out)))
-        out = self.bn3(self.conv3(out))
-        out = out + self.shortcut(x) if self.stride==1 else out
-        return out
-
-
-class MobileNetV2(nn.Module):
-    # (expansion, out_planes, num_blocks, stride)
-    cfg = [(1,  16, 1, 1),
-           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
-           (6,  32, 3, 2),
-           (6,  64, 4, 2),
-           (6,  96, 3, 1),
-           (6, 160, 3, 2),
-           (6, 320, 1, 1)]
-
-    def __init__(self, input_channels, output_num):
-        super(MobileNetV2, self).__init__()
-        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
-        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(32)
-        self.layers = self._make_layers(in_planes=32)
-        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn2 = nn.BatchNorm2d(1280)
-        self.linear = nn.Linear(1280, output_num)
-
-    def _make_layers(self, in_planes):
-        layers = []
-        for expansion, out_planes, num_blocks, stride in self.cfg:
-            strides = [stride] + [1]*(num_blocks-1)
-            for stride in strides:
-                layers.append(Block(in_planes, out_planes, expansion, stride))
-                in_planes = out_planes
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.layers(out)
-        out = F.relu(self.bn2(self.conv2(out)))
-        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
-        out = F.avg_pool2d(out, 4)
-        out = out.view(out.size(0), -1)
-        out = self.linear(out)
-        return out
-
-
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser(description='MobileNetV2 Implementation')
-    parser.add_argument('--input_channels', default=3, type=int)
-    parser.add_argument('--output_num', default=10, type=int)
-    # parser.add_argument('--input_size', default=32, type=int)
-    args = parser.parse_args()
-    
-    model = MobileNetV2(args.input_channels, args.output_num)
-    tensor = torch.rand(1, args.input_channels, 32, 32)
-    pred = model(tensor)
-    
-    print(model)
-    print("Predictions shape:", pred.shape)

+ 0 - 13
model/test.py

@@ -1,13 +0,0 @@
-import os
-import sys
-
-project_root = '/home/yhsun/classification-main/'
-sys.path.append(project_root)
-print("Project root added to sys.path:", project_root)
-
-# Verify that we can access the model package directly
-import model
-print("Model package is accessible, path:", model.__file__)
-
-from model.layer import linear_head
-print("Imported linear_head from model.layer")

+ 0 - 46
model/timm_model.py

@@ -1,46 +0,0 @@
-import timm
-# print(timm.list_models())
-import torch
-# from model.layer import linear_head
-
-import os
-import sys
-
-project_root = '/home/yhsun/classification-main/'
-sys.path.append(project_root)
-# print("Project root added to sys.path:", project_root)
-
-# Verify that we can access the model package directly
-import model
-# print("Model package is accessible, path:", model.__file__)
-
-from model.layer import linear_head
-# print("Imported linear_head from model.layer")
-
-
-class timm_model(torch.nn.Module):
-    def __init__(self, args):
-        super().__init__()
-        self.backbone = timm.create_model(args.model, in_chans=3, features_only=True, exportable=True)
-        out_dim = self.backbone.feature_info.channels()[-1]  # the backbone returns multiple feature maps; use the last one and take its channel count
-        self.linear_head = linear_head(out_dim, args.output_class)
-
-    def forward(self, x):
-        x = self.backbone(x)
-        x = self.linear_head(x[-1])
-        return x
-
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser(description='')
-    parser.add_argument('--model', default='resnet18', type=str)
-    parser.add_argument('--input_size', default=32, type=int)
-    parser.add_argument('--output_class', default=10, type=int)
-    args = parser.parse_args()
-    model = timm_model(args)
-    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
-    pred = model(tensor)
-    print(model)
-    print(pred.shape)

+ 0 - 87
model/yolov7_cls.py

@@ -1,87 +0,0 @@
-# Adapted from yolov7: https://github.com/WongKinYiu/yolov7
-import torch
-import os
-import sys
-
-project_root = '/home/yhsun/classification-main/'
-sys.path.append(project_root)
-# print("Project root added to sys.path:", project_root)
-
-# Verify that we can access the model package directly
-import model
-from model.layer import cbs, elan, mp, sppcspc, linear_head
-
-
-class yolov7_cls(torch.nn.Module):
-    def __init__(self, args):
-        super().__init__()
-        dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}
-        n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}
-        dim = dim_dict[args.model_type]
-        n = n_dict[args.model_type]
-        output_class = args.output_class
-        # network structure
-        if not args.prune:  # standard (unpruned) version
-            self.l0 = cbs(3, dim, 1, 1)
-            self.l1 = cbs(dim, 2 * dim, 3, 2)  # input_size/2
-            self.l2 = cbs(2 * dim, 2 * dim, 1, 1)
-            self.l3 = cbs(2 * dim, 4 * dim, 3, 2)  # input_size/4
-            self.l4 = elan(4 * dim, 8 * dim, n)
-            self.l5 = mp(8 * dim, 8 * dim)  # input_size/8
-            self.l6 = elan(8 * dim, 16 * dim, n)
-            self.l7 = mp(16 * dim, 16 * dim)  # input_size/16
-            self.l8 = elan(16 * dim, 32 * dim, n)
-            self.l9 = mp(32 * dim, 32 * dim)  # input_size/32
-            self.l10 = elan(32 * dim, 32 * dim, n)
-            self.l11 = sppcspc(32 * dim, 16 * dim)
-            self.l12 = cbs(16 * dim, 8 * dim, 1, 1)
-            self.linear_head = linear_head(8 * dim, output_class)
-        else:  # pruned version
-            config = args.prune_num
-            self.l0 = cbs(3, config[0], 1, 1)
-            self.l1 = cbs(config[0], config[1], 3, 2)  # input_size/2
-            self.l2 = cbs(config[1], config[2], 1, 1)
-            self.l3 = cbs(config[2], config[3], 3, 2)  # input_size/4
-            self.l4 = elan(config[3], None, n, config[4:7 + 2 * n])
-            self.l5 = mp(config[6 + 2 * n], None, config[7 + 2 * n:10 + 2 * n])  # input_size/8
-            self.l6 = elan(config[7 + 2 * n] + config[9 + 2 * n], None, n, config[10 + 2 * n:13 + 4 * n])
-            self.l7 = mp(config[12 + 4 * n], None, config[13 + 4 * n:16 + 4 * n])  # input_size/16
-            self.l8 = elan(config[13 + 4 * n] + config[15 + 4 * n], None, n, config[16 + 4 * n:19 + 6 * n])
-            self.l9 = mp(config[18 + 6 * n], None, config[19 + 6 * n:22 + 6 * n])  # input_size/32
-            self.l10 = elan(config[19 + 6 * n] + config[21 + 6 * n], None, n, config[22 + 6 * n:25 + 8 * n])
-            self.l11 = sppcspc(config[24 + 8 * n], None, config[25 + 8 * n:32 + 8 * n])
-            self.l12 = cbs(config[31 + 8 * n], config[32 + 8 * n], 1, 1)
-            self.linear_head = linear_head(config[32 + 8 * n], output_class)
-
-    def forward(self, x):
-        x = self.l0(x)
-        x = self.l1(x)
-        x = self.l2(x)
-        x = self.l3(x)
-        x = self.l4(x)
-        x = self.l5(x)
-        x = self.l6(x)
-        x = self.l7(x)
-        x = self.l8(x)
-        x = self.l9(x)
-        x = self.l10(x)
-        x = self.l11(x)
-        x = self.l12(x)
-        x = self.linear_head(x)
-        return x
-
-
-if __name__ == '__main__':
-    import argparse
-
-    parser = argparse.ArgumentParser(description='')
-    parser.add_argument('--prune', default=False, type=bool)
-    parser.add_argument('--model_type', default='n', type=str)
-    parser.add_argument('--input_size', default=32, type=int)
-    parser.add_argument('--output_class', default=10, type=int)
-    args = parser.parse_args()
-    model = yolov7_cls(args)
-    tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
-    pred = model(tensor)
-    print(model)
-    print(pred.shape)