|
@@ -67,10 +67,18 @@ def train_model(args, train_data, val_data):
|
|
|
|
|
|
# 调整学习率
|
|
|
learning_rate = args.lr if args.lr else 1e-2
|
|
|
- # optimizer = SGD(learning_rate=learning_rate, momentum=args.momentum)
|
|
|
+
|
|
|
+ # Select optimizer based on args.opt
|
|
|
+ if args.opt == 'sgd':
|
|
|
+ optimizer = SGD(learning_rate=learning_rate,
|
|
|
+ momentum=args.momentum if args.momentum else 0.0)
|
|
|
+ elif args.opt == 'adam':
|
|
|
+ optimizer = Adam(learning_rate=learning_rate)
|
|
|
+ else:
|
|
|
+ optimizer = Adam(learning_rate=learning_rate) # Fall back to Adam for unrecognized --opt values (note: --opt defaults to "sgd", so this branch never fires when --opt is omitted)
|
|
|
|
|
|
# Compile model
|
|
|
- model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
|
|
|
+ model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
|
|
|
|
|
|
# Check if a checkpoint exists and determine the initial_epoch
|
|
|
latest_checkpoint = tf.train.latest_checkpoint(args.output_dir)
|
|
@@ -119,16 +127,6 @@ def get_args_parser(add_help=True):
|
|
|
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
|
|
|
parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
|
|
|
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
|
|
- parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
|
|
|
- parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
|
|
|
- parser.add_argument(
|
|
|
- "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
|
|
|
- )
|
|
|
- parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
|
|
|
- parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
|
|
|
- parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
|
|
|
- parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
|
|
|
- parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
|
|
|
|
|
|
parser.add_argument(
|
|
|
"--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"
|