|
@@ -5,6 +5,7 @@ choice_dict = {
|
|
'LeNet': 'model_prepare(args).LeNet()',
|
|
'LeNet': 'model_prepare(args).LeNet()',
|
|
'Alexnet': 'model_prepare(args).Alexnet()',
|
|
'Alexnet': 'model_prepare(args).Alexnet()',
|
|
'VGG19': 'model_prepare(args).VGG19()',
|
|
'VGG19': 'model_prepare(args).VGG19()',
|
|
|
|
+ 'VGG16': 'model_prepare(args).VGG16()',
|
|
'GoogleNet': 'model_prepare(args).GoogleNet()',
|
|
'GoogleNet': 'model_prepare(args).GoogleNet()',
|
|
'resnet': 'model_prepare(args).resnet()'
|
|
'resnet': 'model_prepare(args).resnet()'
|
|
}
|
|
}
|
|
@@ -122,6 +123,11 @@ class model_prepare:
|
|
model = VGG19()
|
|
model = VGG19()
|
|
return model
|
|
return model
|
|
|
|
|
|
|
|
+ def VGG16(self):
|
|
|
|
+ from model.VGG19 import VGG16
|
|
|
|
+ model = VGG16()
|
|
|
|
+ return model
|
|
|
|
+
|
|
def GoogleNet(self):
|
|
def GoogleNet(self):
|
|
from model.GoogleNet import GoogLeNet
|
|
from model.GoogleNet import GoogLeNet
|
|
model = GoogLeNet(self.args.input_channels, self.args.output_num)
|
|
model = GoogLeNet(self.args.input_channels, self.args.output_num)
|