|
@@ -7,7 +7,6 @@ from models.AlexNet import create_model
|
|
from tensorflow.keras.preprocessing import image_dataset_from_directory
|
|
from tensorflow.keras.preprocessing import image_dataset_from_directory
|
|
|
|
|
|
|
|
|
|
-
|
|
|
|
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
|
|
def augment(image):
|
|
def augment(image):
|
|
# Random horizontal flip
|
|
# Random horizontal flip
|
|
@@ -78,7 +77,7 @@ def train_model(args, train_data, val_data):
|
|
# Select optimizer based on args.opt
|
|
# Select optimizer based on args.opt
|
|
if args.opt == 'sgd':
|
|
if args.opt == 'sgd':
|
|
optimizer = SGD(learning_rate=learning_rate,
|
|
optimizer = SGD(learning_rate=learning_rate,
|
|
- momentum=args.momentum if args.momentum else 0.0)
|
|
|
|
|
|
+ momentum=args.momentum if args.momentum else 0.0)
|
|
elif args.opt == 'adam':
|
|
elif args.opt == 'adam':
|
|
optimizer = Adam(learning_rate=learning_rate)
|
|
optimizer = Adam(learning_rate=learning_rate)
|
|
else:
|
|
else:
|
|
@@ -143,6 +142,7 @@ def get_args_parser(add_help=True):
|
|
)
|
|
)
|
|
return parser
|
|
return parser
|
|
|
|
|
|
|
|
+
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
args = get_args_parser().parse_args()
|
|
args = get_args_parser().parse_args()
|
|
|
|
|
|
@@ -154,7 +154,8 @@ if __name__ == "__main__":
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
|
|
# Load data
|
|
# Load data
|
|
- train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)
|
|
|
|
|
|
+ train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size),
|
|
|
|
+ batch_size=args.batch_size)
|
|
|
|
|
|
# Start training
|
|
# Start training
|
|
- train_model(args, train_data, val_data)
|
|
|
|
|
|
+ train_model(args, train_data, val_data)
|