
Use command-line arguments to control the optimizer, learning rate, and other training parameters during training

liyan committed 7 months ago
commit 35b7cbb676
2 changed files with 20 additions and 23 deletions
  1. train_alexnet.py  + 10 - 12
  2. train_vgg16.py  + 10 - 11

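With this change, the optimizer, learning rate, and momentum are chosen from the command line. A hypothetical invocation (flag names taken from the argparse definitions in the diffs below):

    python train_alexnet.py --opt sgd --lr 0.01 --momentum 0.9
    python train_vgg16.py --opt adam --lr 1e-4
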
+ 10 - 12
train_alexnet.py

@@ -67,10 +67,18 @@ def train_model(args, train_data, val_data):

    # Adjust the learning rate
    learning_rate = args.lr if args.lr else 1e-2
-    # optimizer = SGD(learning_rate=learning_rate, momentum=args.momentum)
+
+    # Select optimizer based on args.opt
+    if args.opt == 'sgd':
+        optimizer = SGD(learning_rate=learning_rate,
+                        momentum=args.momentum if args.momentum else 0.0)
+    elif args.opt == 'adam':
+        optimizer = Adam(learning_rate=learning_rate)
+    else:
+        optimizer = Adam(learning_rate=learning_rate)  # Fall back to Adam for unrecognized values

    # Compile model
-    model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
+    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    # Check if a checkpoint exists and determine the initial_epoch
    latest_checkpoint = tf.train.latest_checkpoint(args.output_dir)
@@ -119,16 +127,6 @@ def get_args_parser(add_help=True):
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
-    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
-    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
-    parser.add_argument(
-        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
-    )
-    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
-    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
-    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
-    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
-    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    parser.add_argument(
        "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"

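The if/elif chain above could also be written as a lookup table. A minimal sketch with a hypothetical helper build_optimizer, not part of this commit, assuming the same args fields (opt, lr, momentum):

    import tensorflow as tf

    def build_optimizer(args):
        # Pick a Keras optimizer constructor based on args.opt (hypothetical helper).
        lr = args.lr if args.lr else 1e-2
        builders = {
            'sgd': lambda: tf.keras.optimizers.SGD(
                learning_rate=lr,
                momentum=args.momentum if args.momentum else 0.0),
            'adam': lambda: tf.keras.optimizers.Adam(learning_rate=lr),
        }
        # Unrecognized values fall back to Adam, matching the else branch in the diff.
        return builders.get(args.opt, builders['adam'])()
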
+ 10 - 11
train_vgg16.py

@@ -46,8 +46,17 @@ def train_model(args, train_generator, val_generator):
    # Adjust the learning rate
    learning_rate = args.lr if args.lr else 1e-2

+    # Select optimizer based on args.opt
+    if args.opt == 'sgd':
+        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
+                                            momentum=args.momentum if args.momentum else 0.0)
+    elif args.opt == 'adam':
+        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+    else:
+        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # Fall back to Adam for unrecognized values
+
    # Compile the model
-    model.compile(optimizer=tf.keras.optimizers.Adam(),
+    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

@@ -91,16 +100,6 @@ def get_args_parser(add_help=True):
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
-    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
-    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
-    parser.add_argument(
-        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
-    )
-    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
-    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
-    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
-    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
-    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    parser.add_argument(
        "--input-size", default=224, type=int, help="the random crop size used for training (default: 224)"