소스 검색

添加vgg16模型

liyan 10 달 전
부모
커밋
96b6f0cbef
2개의 변경된 파일7개의 추가작업 그리고 1개의 파일을 삭제
  1. 6 0
      block/model_get.py
  2. 1 1
      model/Alexnet.py

+ 6 - 0
block/model_get.py

@@ -5,6 +5,7 @@ choice_dict = {
     'LeNet': 'model_prepare(args).LeNet()',
     'Alexnet': 'model_prepare(args).Alexnet()',
     'VGG19': 'model_prepare(args).VGG19()',
+    'VGG16': 'model_prepare(args).VGG16()',
     'GoogleNet': 'model_prepare(args).GoogleNet()',
     'resnet': 'model_prepare(args).resnet()'
 }
@@ -122,6 +123,11 @@ class model_prepare:
         model = VGG19()
         return model
 
+    def VGG16(self):
+        from model.VGG19 import VGG16
+        model = VGG16()
+        return model
+
     def GoogleNet(self):
         from model.GoogleNet import GoogLeNet
         model = GoogLeNet(self.args.input_channels, self.args.output_num)

+ 1 - 1
model/Alexnet.py

@@ -53,7 +53,7 @@ class Alexnet(nn.Module):
         
     def forward(self, x):
         x = self.features(x)
-        x = x.view(x.size(0), -1)
+        x = x.reshape(x.size(0), -1)
         x = self.classifier(x)
         return x