فهرست منبع

调整项目代码

liyan 7 ماه پیش
والد
کامیت
a719fd182d
2فایلهای تغییر یافته به همراه7 افزوده شده و 2 حذف شده
  1. 4 0
      models/VGG16.py
  2. 3 2
      train_vgg16.py

+ 4 - 0
models/VGG16.py

@@ -0,0 +1,4 @@
+from tensorflow.keras.applications import VGG16
+
+def create_model(input_shape=(224, 224, 3), num_classes=10):
+    return VGG16(weights=None, include_top=True, input_shape=input_shape, classes=num_classes)

+ 3 - 2
train_vgg16.py

@@ -3,7 +3,8 @@ import os
 import tensorflow as tf
 from keras.callbacks import ModelCheckpoint, CSVLogger
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
-from tensorflow.keras.applications import VGG16
+
+from models.VGG16 import create_model
 
 
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
@@ -40,7 +41,7 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
 
 def train_model(args, train_generator, val_generator):
     # Create model
-    model = VGG16(weights=None, include_top=True, input_shape=(224, 224, 3), classes=10)
+    model = create_model()
 
     # 调整学习率
     learning_rate = args.lr if args.lr else 1e-2