train.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. import datetime
  2. import os
  3. import time
  4. import warnings
  5. import presets
  6. import torch
  7. import torch.utils.data
  8. import torchvision
  9. import torchvision.transforms
  10. import utils
  11. from sampler import RASampler
  12. from torch import nn
  13. from torch.utils.data.dataloader import default_collate
  14. from torchvision.transforms.functional import InterpolationMode
  15. from transforms import get_mixup_cutmix
  16. def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
  17. model.train()
  18. metric_logger = utils.MetricLogger(delimiter=" ")
  19. metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
  20. metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
  21. header = f"Epoch: [{epoch}]"
  22. for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
  23. start_time = time.time()
  24. image, target = image.to(device), target.to(device)
  25. with torch.cuda.amp.autocast(enabled=scaler is not None):
  26. output = model(image)
  27. loss = criterion(output, target)
  28. optimizer.zero_grad()
  29. if scaler is not None:
  30. scaler.scale(loss).backward()
  31. if args.clip_grad_norm is not None:
  32. # we should unscale the gradients of optimizer's assigned params if do gradient clipping
  33. scaler.unscale_(optimizer)
  34. nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
  35. scaler.step(optimizer)
  36. scaler.update()
  37. else:
  38. loss.backward()
  39. if args.clip_grad_norm is not None:
  40. nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
  41. optimizer.step()
  42. if model_ema and i % args.model_ema_steps == 0:
  43. model_ema.update_parameters(model)
  44. if epoch < args.lr_warmup_epochs:
  45. # Reset ema buffer to keep copying weights during warmup period
  46. model_ema.n_averaged.fill_(0)
  47. acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
  48. batch_size = image.shape[0]
  49. metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
  50. metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
  51. metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
  52. metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
  53. def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
  54. model.eval()
  55. metric_logger = utils.MetricLogger(delimiter=" ")
  56. header = f"Test: {log_suffix}"
  57. num_processed_samples = 0
  58. with torch.inference_mode():
  59. for image, target in metric_logger.log_every(data_loader, print_freq, header):
  60. image = image.to(device, non_blocking=True)
  61. target = target.to(device, non_blocking=True)
  62. output = model(image)
  63. loss = criterion(output, target)
  64. acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
  65. # FIXME need to take into account that the datasets
  66. # could have been padded in distributed setup
  67. batch_size = image.shape[0]
  68. metric_logger.update(loss=loss.item())
  69. metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
  70. metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
  71. num_processed_samples += batch_size
  72. # gather the stats from all processes
  73. num_processed_samples = utils.reduce_across_processes(num_processed_samples)
  74. if (
  75. hasattr(data_loader.dataset, "__len__")
  76. and len(data_loader.dataset) != num_processed_samples
  77. and torch.distributed.get_rank() == 0
  78. ):
  79. # See FIXME above
  80. warnings.warn(
  81. f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
  82. "samples were used for the validation, which might bias the results. "
  83. "Try adjusting the batch size and / or the world size. "
  84. "Setting the world size to 1 is always a safe bet."
  85. )
  86. metric_logger.synchronize_between_processes()
  87. print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
  88. return metric_logger.acc1.global_avg
  89. def _get_cache_path(filepath):
  90. import hashlib
  91. h = hashlib.sha1(filepath.encode()).hexdigest()
  92. cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
  93. cache_path = os.path.expanduser(cache_path)
  94. return cache_path
def load_data(traindir, valdir, args):
    """Build the train/val datasets and their samplers.

    Datasets are plain ImageFolder trees with preset transforms; if
    `args.cache_dataset` is set, the (dataset, dir) pair is pickled to a
    per-path cache file and reloaded on later runs. Returns
    `(dataset, dataset_test, train_sampler, test_sampler)`.
    """
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = (
        args.val_resize_size,
        args.val_crop_size,
        args.train_crop_size,
    )
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        # TODO: this could probably be weights_only=True
        dataset, _ = torch.load(cache_path, weights_only=False)
    else:
        # We need a default value for the variables below because args may come
        # from train_quantization.py which doesn't define them.
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = getattr(args, "ra_magnitude", None)
        augmix_severity = getattr(args, "augmix_severity", None)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
                backend=args.backend,
                use_v2=args.use_v2,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            # The source dir is saved alongside so the cache entry is self-describing.
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        # TODO: this could probably be weights_only=True
        dataset_test, _ = torch.load(cache_path, weights_only=False)
    else:
        if args.weights and args.test_only:
            # Test-only with pretrained weights: use the weights' own eval transforms.
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms(antialias=True)
            if args.backend == "tensor":
                preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size,
                resize_size=val_resize_size,
                interpolation=interpolation,
                backend=args.backend,
                use_v2=args.use_v2,
            )
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            # Repeated Augmentation sampler (distributed only).
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
def main(args):
    """Entry point: build data loaders, model, optimizer and schedulers, then train.

    Honors `args.test_only` (evaluation-only mode) and `args.resume`
    (checkpoint restore). When `args.output_dir` is set, a checkpoint is
    written after every epoch (per-epoch file plus rolling `checkpoint.pth`).
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        # Determinism requires disabling cudnn autotuning.
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    num_classes = len(dataset.classes)
    mixup_cutmix = get_mixup_cutmix(
        mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
    )
    if mixup_cutmix is not None:
        # Mixup/CutMix operate on whole batches, so they are applied at collate time.
        def collate_fn(batch):
            return mixup_cutmix(*default_collate(batch))

    else:
        collate_fn = default_collate

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Per-parameter-group weight decay overrides (bias / transformer embeddings).
    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            # "sgd_nesterov" enables Nesterov momentum.
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        # Cosine anneals over the post-warmup portion of training only.
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        # Warmup scheduler runs first, then hands over to the main scheduler.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # Keep an unwrapped handle for checkpointing/EMA regardless of DDP wrapping.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            # Optimizer/scheduler state only matters when training continues.
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle per epoch in distributed mode.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
    """Build and return the ArgumentParser for this training script.

    Kept as a factory (rather than parsing at import time) so other scripts
    can reuse or extend the parser.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="dataset/CIFAR-10", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and launch training/evaluation.
    args = get_args_parser().parse_args()
    main(args)