|
@@ -3,7 +3,8 @@ import os
|
|
|
import tensorflow as tf
|
|
|
from keras.callbacks import ModelCheckpoint, CSVLogger
|
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
|
-from tensorflow.keras.applications import VGG16
|
|
|
+
|
|
|
+from models.VGG16 import create_model
|
|
|
|
|
|
|
|
|
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
@@ -40,7 +41,7 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
|
|
|
|
def train_model(args, train_generator, val_generator):
|
|
|
# Create model
|
|
|
- model = VGG16(weights=None, include_top=True, input_shape=(224, 224, 3), classes=10)
|
|
|
+ model = create_model()
|
|
|
|
|
|
# 调整学习率
|
|
|
learning_rate = args.lr if args.lr else 1e-2
|