|
@@ -8,6 +8,14 @@ from models.VGG16 import create_model
|
|
|
|
|
|
|
|
|
|
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
|
|
+ # 给定的均值和标准差
|
|
|
|
+ mean = tf.constant([0.485, 0.456, 0.406])
|
|
|
|
+ std = tf.constant([0.229, 0.224, 0.225])
|
|
|
|
+
|
|
|
|
+ # 自定义标准化函数
|
|
|
|
+ def normalize_image(x):
|
|
|
|
+ return (x - mean) / std # 标准化
|
|
|
|
+
|
|
# 使用 ImageDataGenerator 加载图像并进行预处理
|
|
# 使用 ImageDataGenerator 加载图像并进行预处理
|
|
train_datagen = ImageDataGenerator(
|
|
train_datagen = ImageDataGenerator(
|
|
rescale=1.0 / 255.0, # 归一化
|
|
rescale=1.0 / 255.0, # 归一化
|
|
@@ -17,10 +25,14 @@ def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
shear_range=0.2,
|
|
shear_range=0.2,
|
|
zoom_range=0.2,
|
|
zoom_range=0.2,
|
|
horizontal_flip=True,
|
|
horizontal_flip=True,
|
|
- fill_mode='nearest'
|
|
|
|
|
|
+ fill_mode='nearest',
|
|
|
|
+ preprocessing_function=normalize_image # 使用自定义的标准化函数
|
|
)
|
|
)
|
|
|
|
|
|
- val_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
|
|
|
|
|
|
+ val_datagen = ImageDataGenerator(
|
|
|
|
+ rescale=1.0 / 255.0,
|
|
|
|
+ preprocessing_function=normalize_image # 使用自定义的标准化函数
|
|
|
|
+ )
|
|
|
|
|
|
train_generator = train_datagen.flow_from_directory(
|
|
train_generator = train_datagen.flow_from_directory(
|
|
train_dir,
|
|
train_dir,
|
|
@@ -85,7 +97,7 @@ def train_model(args, train_generator, val_generator):
|
|
|
|
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
# Define ModelCheckpoint callback to save weights for each epoch
|
|
checkpoint_callback = ModelCheckpoint(
|
|
checkpoint_callback = ModelCheckpoint(
|
|
- os.path.join(args.output_dir, 'vgg16_loss_{val_loss:.4f}_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
|
|
|
|
|
|
+ os.path.join(args.output_dir, 'vgg16_{epoch:03d}.h5'), # Save weights as vgg16_{epoch}.h5
|
|
save_weights_only=False,
|
|
save_weights_only=False,
|
|
monitor='val_loss', # Monitor the validation loss
|
|
monitor='val_loss', # Monitor the validation loss
|
|
save_freq='epoch', # Save after every epoch
|
|
save_freq='epoch', # Save after every epoch
|
|
@@ -120,7 +132,7 @@ def get_args_parser(add_help=True):
|
|
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
|
|
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
|
|
|
|
|
|
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
|
|
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
|
|
- parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
|
|
|
|
|
|
+ parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
|
|
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
|
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
|
|
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|