Pārlūkot izejas kodu

修改文件缩进格式

liyan 7 mēneši atpakaļ
vecāks
revīzija
09ae8c89da
2 mainītis faili ar 9 papildinājumiem un 6 dzēšanām
  1. 5 4
      train_alexnet.py
  2. 4 2
      train_vgg16.py

+ 5 - 4
train_alexnet.py

@@ -7,7 +7,6 @@ from models.AlexNet import create_model
 from tensorflow.keras.preprocessing import image_dataset_from_directory
 
 
-
 def load_data(train_dir, val_dir, img_size=(224, 224), batch_size=32):
     def augment(image):
         # Random horizontal flip
@@ -78,7 +77,7 @@ def train_model(args, train_data, val_data):
     # Select optimizer based on args.opt
     if args.opt == 'sgd':
         optimizer = SGD(learning_rate=learning_rate,
-                                            momentum=args.momentum if args.momentum else 0.0)
+                        momentum=args.momentum if args.momentum else 0.0)
     elif args.opt == 'adam':
         optimizer = Adam(learning_rate=learning_rate)
     else:
@@ -143,6 +142,7 @@ def get_args_parser(add_help=True):
     )
     return parser
 
+
 if __name__ == "__main__":
     args = get_args_parser().parse_args()
 
@@ -154,7 +154,8 @@ if __name__ == "__main__":
     os.makedirs(args.output_dir, exist_ok=True)
 
     # Load data
-    train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)
+    train_data, val_data = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size),
+                                     batch_size=args.batch_size)
 
     # Start training
-    train_model(args, train_data, val_data)
+    train_model(args, train_data, val_data)

+ 4 - 2
train_vgg16.py

@@ -140,6 +140,7 @@ def get_args_parser(add_help=True):
     )
     return parser
 
+
 if __name__ == "__main__":
     args = get_args_parser().parse_args()
 
@@ -151,7 +152,8 @@ if __name__ == "__main__":
     os.makedirs(args.output_dir, exist_ok=True)
 
     # Load data
-    train_generator, val_generator = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size), batch_size=args.batch_size)
+    train_generator, val_generator = load_data(train_dir, val_dir, img_size=(args.input_size, args.input_size),
+                                               batch_size=args.batch_size)
 
     # Start training
-    train_model(args, train_generator, val_generator)
+    train_model(args, train_generator, val_generator)