瀏覽代碼

解决vgg输入参数大小问题

liyan 9 月之前
父節點
當前提交
1d4341fb9e
共有 3 個文件被更改,包括 4 次插入4 次删除
  1. 1 1
      block/model_get.py
  2. 2 2
      model/VGG19.py
  3. 1 1
      train.py

+ 1 - 1
block/model_get.py

@@ -125,7 +125,7 @@ class model_prepare:
 
     def VGG16(self):
         from model.VGG19 import VGG16
-        model = VGG16()
+        model = VGG16(self.args.input_size)
         return model
 
     def GoogleNet(self):

+ 2 - 2
model/VGG19.py

@@ -41,8 +41,8 @@ def VGG11():
 def VGG13():
     return VGG('VGG13')
 
-def VGG16():
-    return VGG('VGG16')
+def VGG16(input_size):
+    return VGG(name='VGG16', input_size=input_size)
 
 def VGG19():
     return VGG('VGG19')

+ 1 - 1
train.py

@@ -115,7 +115,7 @@ if args.local_rank == 0:
     elif args.prune:
         print(f'| 加载模型+剪枝训练:{args.prune_weight} |')
     else:  # 创建自定义模型args.model
-        assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
+        # assert os.path.exists(f'model/{args.model}.py'), f'! 没有自定义模型:{args.model} !'
         print(f'| 创建自定义模型:{args.model} |')
 # -------------------------------------------------------------------------------------------------------------------- #
 if __name__ == '__main__':