
Initialize project

liyan committed 8 months ago
commit
742a3ae79d
9 changed files with 2025 additions and 0 deletions
  1. .gitignore (+4 -0)
  2. README.md (+368 -0)
  3. dataset/README.md (+1 -0)
  4. presets.py (+119 -0)
  5. sampler.py (+62 -0)
  6. train.py (+528 -0)
  7. train_quantization.py (+273 -0)
  8. transforms.py (+206 -0)
  9. utils.py (+464 -0)

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+.idea
+__pycache__
+dataset/*
+!dataset/README.md

+ 368 - 0
README.md

The diff for this file was suppressed because it is too large


+ 1 - 0
dataset/README.md

@@ -0,0 +1 @@
+The dataset directory is used to store the datasets

+ 119 - 0
presets.py

@@ -0,0 +1,119 @@
+import torch
+from torchvision.transforms.functional import InterpolationMode
+
+
+def get_module(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2
+    else:
+        import torchvision.transforms
+
+        return torchvision.transforms
+
+
+class ClassificationPresetTrain:
+    # Note: this transform assumes that the inputs to forward() are always PIL
+    # images, regardless of the backend parameter. We may change that in the
+    # future though, if we change the output type from the dataset.
+    def __init__(
+        self,
+        *,
+        crop_size,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        interpolation=InterpolationMode.BILINEAR,
+        hflip_prob=0.5,
+        auto_augment_policy=None,
+        ra_magnitude=9,
+        augmix_severity=3,
+        random_erase_prob=0.0,
+        backend="pil",
+        use_v2=False,
+    ):
+        T = get_module(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
+
+        transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        if hflip_prob > 0:
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
+        if auto_augment_policy is not None:
+            if auto_augment_policy == "ra":
+                transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+            elif auto_augment_policy == "ta_wide":
+                transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
+            elif auto_augment_policy == "augmix":
+                transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
+            else:
+                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
+
+        if backend == "pil":
+            transforms.append(T.PILToTensor())
+
+        transforms.extend(
+            [
+                T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
+                T.Normalize(mean=mean, std=std),
+            ]
+        )
+        if random_erase_prob > 0:
+            transforms.append(T.RandomErasing(p=random_erase_prob))
+
+        if use_v2:
+            transforms.append(T.ToPureTensor())
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(
+        self,
+        *,
+        crop_size,
+        resize_size=256,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        interpolation=InterpolationMode.BILINEAR,
+        backend="pil",
+        use_v2=False,
+    ):
+        T = get_module(use_v2)
+        transforms = []
+        backend = backend.lower()
+        if backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
+
+        transforms += [
+            T.Resize(resize_size, interpolation=interpolation, antialias=True),
+            T.CenterCrop(crop_size),
+        ]
+
+        if backend == "pil":
+            transforms.append(T.PILToTensor())
+
+        transforms += [
+            T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
+        ]
+
+        if use_v2:
+            transforms.append(T.ToPureTensor())
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img):
+        return self.transforms(img)
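
A quick usage sketch of the presets above (the values mirror the script defaults; the input image is a synthetic stand-in, not part of this repo):

    from PIL import Image
    import presets

    train_tf = presets.ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide")
    eval_tf = presets.ClassificationPresetEval(crop_size=224, resize_size=256)

    img = Image.new("RGB", (320, 320))  # placeholder for a dataset sample
    x = train_tf(img)  # normalized float tensor of shape (3, 224, 224)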

+ 62 - 0
sampler.py

@@ -0,0 +1,62 @@
+import math
+
+import torch
+import torch.distributed as dist
+
+
+class RASampler(torch.utils.data.Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed,
+    with repeated augmentation.
+    It ensures that different each augmented version of a sample will be visible to a
+    different process (GPU).
+    Heavily based on 'torch.utils.data.DistributedSampler'.
+
+    This is borrowed from the DeiT Repo:
+    https://github.com/facebookresearch/deit/blob/main/samplers.py
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available!")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available!")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
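+        # i.e. keep floor(len(dataset) / 256) * 256 samples per epoch, split
+        # evenly across replicas, so each process yields the same number of indices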
+        self.shuffle = shuffle
+        self.seed = seed
+        self.repetitions = repetitions
+
+    def __iter__(self):
+        if self.shuffle:
+            # Deterministically shuffle based on epoch
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # Add extra samples to make it evenly divisible
+        indices = [ele for ele in indices for i in range(self.repetitions)]
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # Subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices[: self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
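
A minimal sketch of exercising the sampler without initializing torch.distributed (num_replicas and rank are passed explicitly; the toy dataset is illustrative):

    import torch
    from sampler import RASampler

    dataset = torch.arange(1000)  # stands in for a map-style dataset
    sampler = RASampler(dataset, num_replicas=2, rank=0, repetitions=3)
    sampler.set_epoch(0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)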

+ 528 - 0
train.py

@@ -0,0 +1,528 @@
+import datetime
+import os
+import time
+import warnings
+
+import presets
+import torch
+import torch.utils.data
+import torchvision
+import torchvision.transforms
+import utils
+from sampler import RASampler
+from torch import nn
+from torch.utils.data.dataloader import default_collate
+from torchvision.transforms.functional import InterpolationMode
+from transforms import get_mixup_cutmix
+
+
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
+    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
+
+    header = f"Epoch: [{epoch}]"
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+        start_time = time.time()
+        image, target = image.to(device), target.to(device)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            if args.clip_grad_norm is not None:
+                # unscale the gradients of the optimizer's assigned params before gradient clipping
+                scaler.unscale_(optimizer)
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            optimizer.step()
+
+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
+        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        batch_size = image.shape[0]
+        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
+
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = f"Test: {log_suffix}"
+
+    num_processed_samples = 0
+    with torch.inference_mode():
+        for image, target in metric_logger.log_every(data_loader, print_freq, header):
+            image = image.to(device, non_blocking=True)
+            target = target.to(device, non_blocking=True)
+            output = model(image)
+            loss = criterion(output, target)
+
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+            # FIXME need to take into account that the datasets
+            # could have been padded in distributed setup
+            batch_size = image.shape[0]
+            metric_logger.update(loss=loss.item())
+            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
+    # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
+    metric_logger.synchronize_between_processes()
+
+    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
+    return metric_logger.acc1.global_avg
+
+
+def _get_cache_path(filepath):
+    import hashlib
+
+    h = hashlib.sha1(filepath.encode()).hexdigest()
+    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
+    cache_path = os.path.expanduser(cache_path)
+    return cache_path
+
+
+def load_data(traindir, valdir, args):
+    # Data loading code
+    print("Loading data")
+    val_resize_size, val_crop_size, train_crop_size = (
+        args.val_resize_size,
+        args.val_crop_size,
+        args.train_crop_size,
+    )
+    interpolation = InterpolationMode(args.interpolation)
+
+    print("Loading training data")
+    st = time.time()
+    cache_path = _get_cache_path(traindir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Note: the transforms are cached together with the dataset!
+        print(f"Loading dataset_train from {cache_path}")
+        # TODO: this could probably be weights_only=True
+        dataset, _ = torch.load(cache_path, weights_only=False)
+    else:
+        # We need a default value for the variables below because args may come
+        # from train_quantization.py which doesn't define them.
+        auto_augment_policy = getattr(args, "auto_augment", None)
+        random_erase_prob = getattr(args, "random_erase", 0.0)
+        ra_magnitude = getattr(args, "ra_magnitude", None)
+        augmix_severity = getattr(args, "augmix_severity", None)
+        dataset = torchvision.datasets.ImageFolder(
+            traindir,
+            presets.ClassificationPresetTrain(
+                crop_size=train_crop_size,
+                interpolation=interpolation,
+                auto_augment_policy=auto_augment_policy,
+                random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
+                backend=args.backend,
+                use_v2=args.use_v2,
+            ),
+        )
+        if args.cache_dataset:
+            print(f"Saving dataset_train to {cache_path}")
+            utils.mkdir(os.path.dirname(cache_path))
+            utils.save_on_master((dataset, traindir), cache_path)
+    print("Took", time.time() - st)
+
+    print("Loading validation data")
+    cache_path = _get_cache_path(valdir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Note: the transforms are cached together with the dataset!
+        print(f"Loading dataset_test from {cache_path}")
+        # TODO: this could probably be weights_only=True
+        dataset_test, _ = torch.load(cache_path, weights_only=False)
+    else:
+        if args.weights and args.test_only:
+            weights = torchvision.models.get_weight(args.weights)
+            preprocessing = weights.transforms(antialias=True)
+            if args.backend == "tensor":
+                preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
+
+        else:
+            preprocessing = presets.ClassificationPresetEval(
+                crop_size=val_crop_size,
+                resize_size=val_resize_size,
+                interpolation=interpolation,
+                backend=args.backend,
+                use_v2=args.use_v2,
+            )
+
+        dataset_test = torchvision.datasets.ImageFolder(
+            valdir,
+            preprocessing,
+        )
+        if args.cache_dataset:
+            print(f"Saving dataset_test to {cache_path}")
+            utils.mkdir(os.path.dirname(cache_path))
+            utils.save_on_master((dataset_test, valdir), cache_path)
+
+    print("Creating data loaders")
+    if args.distributed:
+        if hasattr(args, "ra_sampler") and args.ra_sampler:
+            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
+        else:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
+    else:
+        train_sampler = torch.utils.data.RandomSampler(dataset)
+        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+    return dataset, dataset_test, train_sampler, test_sampler
+
+
+def main(args):
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True
+
+    train_dir = os.path.join(args.data_path, "train")
+    val_dir = os.path.join(args.data_path, "val")
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    num_classes = len(dataset.classes)
+    mixup_cutmix = get_mixup_cutmix(
+        mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
+    )
+    if mixup_cutmix is not None:
+
+        def collate_fn(batch):
+            return mixup_cutmix(*default_collate(batch))
+
+    else:
+        collate_fn = default_collate
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=train_sampler,
+        num_workers=args.workers,
+        pin_memory=True,
+        collate_fn=collate_fn,
+    )
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
+    )
+
+    print("Creating model")
+    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
+    model.to(device)
+
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+
+    custom_keys_weight_decay = []
+    if args.bias_weight_decay is not None:
+        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+    if args.transformer_embedding_decay is not None:
+        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
+            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+    parameters = utils.set_weight_decay(
+        model,
+        args.weight_decay,
+        norm_weight_decay=args.norm_weight_decay,
+        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+    )
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "rmsprop":
+        optimizer = torch.optim.RMSprop(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
+        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
+
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == "steplr":
+        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    elif args.lr_scheduler == "cosineannealinglr":
+        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
+        )
+    elif args.lr_scheduler == "exponentiallr":
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
+    else:
+        raise RuntimeError(
+            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
+            "are supported."
+        )
+
+    if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == "linear":
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+            )
+        elif args.lr_warmup_method == "constant":
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
+                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+            )
+        else:
+            raise RuntimeError(
+                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
+            )
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    model_ema = None
+    if args.model_ema:
+        # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
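+        # Worked example (illustrative numbers): 8 GPUs x batch 128 x 32 EMA steps
+        # over 600 epochs gives adjust ~= 54.6, so decay=0.99998 becomes an
+        # effective per-update decay of about 1 - 2e-5 * 54.6 ~= 0.9989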
+        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
+
+    if args.resume:
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
+        model_without_ddp.load_state_dict(checkpoint["model"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint["model_ema"])
+        if scaler:
+            scaler.load_state_dict(checkpoint["scaler"])
+
+    if args.test_only:
+        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        else:
+            evaluate(model, criterion, data_loader_test, device=device)
+        return
+
+    print("Start training")
+    start_time = time.time()
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
+        lr_scheduler.step()
+        evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        if args.output_dir:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+                "epoch": epoch,
+                "args": args,
+            }
+            if model_ema:
+                checkpoint["model_ema"] = model_ema.state_dict()
+            if scaler:
+                checkpoint["scaler"] = scaler.state_dict()
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print(f"Training time {total_time_str}")
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="dataset/CIFAR-10", type=str, help="dataset path")
+    parser.add_argument("--model", default="resnet18", type=str, help="model name")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
+    parser.add_argument(
+        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+    )
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument(
+        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
+    )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=1e-4,
+        type=float,
+        metavar="W",
+        help="weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--bias-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--transformer-embedding-decay",
+        default=None,
+        type=float,
+        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
+    )
+    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
+    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
+    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
+    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
+    parser.add_argument(
+        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
+    )
+    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+    parser.add_argument(
+        "--cache-dataset",
+        dest="cache_dataset",
+        help="Cache the datasets for quicker initialization. It also serializes the transforms",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--test-only",
+        dest="test_only",
+        help="Only test the model",
+        action="store_true",
+    )
+    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
+    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
+
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
+    # distributed training parameters
+    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+    parser.add_argument(
+        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
+    )
+    parser.add_argument(
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
+    parser.add_argument(
+        "--model-ema-decay",
+        type=float,
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
+    )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
+    parser.add_argument(
+        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
+    )
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    main(args)
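
A hedged single-process sketch of driving the script programmatically; the data path is a placeholder that must contain ImageFolder-style train/ and val/ directories:

    from train import get_args_parser, main

    args = get_args_parser().parse_args(
        ["--data-path", "dataset/CIFAR-10", "--model", "resnet18",
         "--device", "cpu", "--epochs", "1", "--batch-size", "32"]
    )
    main(args)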

+ 273 - 0
train_quantization.py

@@ -0,0 +1,273 @@
+import copy
+import datetime
+import os
+import time
+
+import torch
+import torch.ao.quantization
+import torch.utils.data
+import torchvision
+import utils
+from torch import nn
+from train import evaluate, load_data, train_one_epoch
+
+
+def main(args):
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print(args)
+
+    if args.post_training_quantize and args.distributed:
+        raise RuntimeError("Post training quantization example should not be performed on distributed mode")
+
+    # Set backend engine to ensure that quantized model runs on the correct kernels
+    if args.qbackend not in torch.backends.quantized.supported_engines:
+        raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
+    torch.backends.quantized.engine = args.qbackend
+
+    device = torch.device(args.device)
+    torch.backends.cudnn.benchmark = True
+
+    # Data loading code
+    print("Loading data")
+    train_dir = os.path.join(args.data_path, "train")
+    val_dir = os.path.join(args.data_path, "val")
+
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
+    )
+
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
+    )
+
+    print("Creating model", args.model)
+    # when training quantized models, we always start from a pre-trained fp32 reference model
+    prefix = "quantized_"
+    model_name = args.model
+    if not model_name.startswith(prefix):
+        model_name = prefix + model_name
+    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
+    model.to(device)
+
+    if not (args.test_only or args.post_training_quantize):
+        model.fuse_model(is_qat=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
+
+        if args.distributed and args.sync_bn:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
+        )
+
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+
+    criterion = nn.CrossEntropyLoss()
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    if args.resume:
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
+        model_without_ddp.load_state_dict(checkpoint["model"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+
+    if args.post_training_quantize:
+        # calibrate the observers on a small subset of the training dataset
+        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
+        data_loader_calibration = torch.utils.data.DataLoader(
+            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
+        )
+        model.eval()
+        model.fuse_model(is_qat=False)
+        model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
+        torch.ao.quantization.prepare(model, inplace=True)
+        # Calibrate first
+        print("Calibrating")
+        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
+        torch.ao.quantization.convert(model, inplace=True)
+        if args.output_dir:
+            print("Saving quantized model")
+            if utils.is_main_process():
+                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
+        print("Evaluating post-training quantized model")
+        evaluate(model, criterion, data_loader_test, device=device)
+        return
+
+    if args.test_only:
+        evaluate(model, criterion, data_loader_test, device=device)
+        return
+
+    model.apply(torch.ao.quantization.enable_observer)
+    model.apply(torch.ao.quantization.enable_fake_quant)
+    start_time = time.time()
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        print("Starting training for epoch", epoch)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
+        lr_scheduler.step()
+        with torch.inference_mode():
+            if epoch >= args.num_observer_update_epochs:
+                print("Disabling observer for subseq epochs, epoch = ", epoch)
+                model.apply(torch.ao.quantization.disable_observer)
+            if epoch >= args.num_batch_norm_update_epochs:
+                print("Freezing BN for subseq epochs, epoch = ", epoch)
+                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+            print("Evaluate QAT model")
+
+            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
+            quantized_eval_model = copy.deepcopy(model_without_ddp)
+            quantized_eval_model.eval()
+            quantized_eval_model.to(torch.device("cpu"))
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+
+            print("Evaluate Quantized model")
+            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
+
+        model.train()
+
+        if args.output_dir:
+            checkpoint = {
+                "model": model_without_ddp.state_dict(),
+                "eval_model": quantized_eval_model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+                "epoch": epoch,
+                "args": args,
+            }
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+        print("Saving models after epoch ", epoch)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print(f"Training time {total_time_str}")
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
+    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
+    parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
+
+    parser.add_argument(
+        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+    )
+    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
+    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument(
+        "--num-observer-update-epochs",
+        default=4,
+        type=int,
+        metavar="N",
+        help="number of total epochs to update observers",
+    )
+    parser.add_argument(
+        "--num-batch-norm-update-epochs",
+        default=3,
+        type=int,
+        metavar="N",
+        help="number of total epochs to update batch norm stats",
+    )
+    parser.add_argument(
+        "--num-calibration-batches",
+        default=32,
+        type=int,
+        metavar="N",
+        help="number of batches of training set for \
+                              observer calibration ",
+    )
+
+    parser.add_argument(
+        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
+    )
+    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=1e-4,
+        type=float,
+        metavar="W",
+        help="weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
+    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+    parser.add_argument(
+        "--cache-dataset",
+        dest="cache_dataset",
+        help="Cache the datasets for quicker initialization. \
+             It also serializes the transforms",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--test-only",
+        dest="test_only",
+        help="Only test the model",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--post-training-quantize",
+        dest="post_training_quantize",
+        help="Post training quantize the model",
+        action="store_true",
+    )
+
+    # distributed training parameters
+    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+
+    return parser
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    if args.backend in ("fbgemm", "qnnpack"):
+        raise ValueError(
+            "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
+            "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
+        )
+    main(args)
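
An illustrative post-training-quantization run (single process, since PTQ refuses distributed mode; the data path is a placeholder and the fbgemm engine assumes an x86 host):

    from train_quantization import get_args_parser, main

    args = get_args_parser().parse_args(
        ["--data-path", "dataset/CIFAR-10", "--model", "mobilenet_v2",
         "--device", "cpu", "--qbackend", "fbgemm", "--post-training-quantize"]
    )
    main(args)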

+ 206 - 0
transforms.py

@@ -0,0 +1,206 @@
+import math
+from typing import Tuple
+
+import torch
+from presets import get_module
+from torch import Tensor
+from torchvision.transforms import functional as F
+
+
+def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
+    transforms_module = get_module(use_v2)
+
+    mixup_cutmix = []
+    if mixup_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
+            if use_v2
+            else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
+        )
+    if cutmix_alpha > 0:
+        mixup_cutmix.append(
+            transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
+            if use_v2
+            else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
+        )
+    if not mixup_cutmix:
+        return None
+
+    return transforms_module.RandomChoice(mixup_cutmix)
+
+
+class RandomMixUp(torch.nn.Module):
+    """Randomly apply MixUp to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
+        super().__init__()
+
+        if num_classes < 1:
+            raise ValueError(
+                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
+            )
+
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+        if target.ndim != 1:
+            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+        if not batch.is_floating_point():
+            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+        if target.dtype != torch.int64:
+            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as in the mixup paper, page 3.
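+        # (sampling Dirichlet([alpha, alpha]) and keeping the first coordinate is
+        # equivalent to drawing lambda from Beta(alpha, alpha))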
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"num_classes={self.num_classes}"
+            f", p={self.p}"
+            f", alpha={self.alpha}"
+            f", inplace={self.inplace}"
+            f")"
+        )
+        return s
+
+
+class RandomCutMix(torch.nn.Module):
+    """Randomly apply CutMix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
+        super().__init__()
+        if num_classes < 1:
+            raise ValueError("Please provide a valid positive value for the num_classes.")
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+        if target.ndim != 1:
+            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+        if not batch.is_floating_point():
+            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+        if target.dtype != torch.int64:
+            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        _, H, W = F.get_dimensions(batch)
+
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
+
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
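+        # with half-sizes r*W and r*H, the cut box covers (2r)^2 = 1 - lambda_param
+        # of the image area, matching the paper's combination ratio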
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
+
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"num_classes={self.num_classes}"
+            f", p={self.p}"
+            f", alpha={self.alpha}"
+            f", inplace={self.inplace}"
+            f")"
+        )
+        return s
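
Sketch of wiring the returned transform into a DataLoader collate_fn, mirroring how train.py uses it (the alpha values and num_classes are illustrative):

    from torch.utils.data.dataloader import default_collate
    from transforms import get_mixup_cutmix

    mixup_cutmix = get_mixup_cutmix(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=10, use_v2=False)

    def collate_fn(batch):
        return mixup_cutmix(*default_collate(batch))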

+ 464 - 0
utils.py

@@ -0,0 +1,464 @@
+import copy
+import datetime
+import errno
+import hashlib
+import os
+import time
+from collections import defaultdict, deque, OrderedDict
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+
+class SmoothedValue:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        t = reduce_across_processes([self.count, self.total])
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+        )
+
+
+class MetricLogger:
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(f"{name}: {str(meter)}")
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join(
+                [
+                    header,
+                    "[{0" + space_fmt + "}/{1}]",
+                    "eta: {eta}",
+                    "{meters}",
+                    "time: {time}",
+                    "data: {data}",
+                    "max mem: {memory:.0f}",
+                ]
+            )
+        else:
+            log_msg = self.delimiter.join(
+                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
+            )
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(f"{header} Total time: {total_time_str}")
+
+
+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+    is used to compute the EMA.
+    """
+
+    def __init__(self, model, decay, device="cpu"):
+        def ema_avg(avg_model_param, model_param, num_averaged):
+            return decay * avg_model_param + (1 - decay) * model_param
+
+        super().__init__(model, device, ema_avg, use_buffers=True)
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.inference_mode():
+        maxk = max(topk)
+        batch_size = target.size(0)
+        if target.ndim == 2:
+            target = target.max(dim=1)[1]
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target[None])
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
+            res.append(correct_k * (100.0 / batch_size))
+        return res
+
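+# Quick sanity check for accuracy() on a dummy batch (shapes illustrative):
+#
+#     logits = torch.randn(8, 10)           # 8 samples, 10 classes
+#     labels = torch.randint(0, 10, (8,))
+#     top1, top5 = accuracy(logits, labels, topk=(1, 5))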
+
+def mkdir(path):
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def setup_for_distributed(is_master):
+    """
+    Disables printing when not in the master process; individual calls can
+    force output with ``print(..., force=True)``.
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    elif hasattr(args, "rank"):
+        pass
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
+    torch.distributed.init_process_group(
+        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
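+# init_distributed_mode() reads RANK, WORLD_SIZE and LOCAL_RANK from the
+# environment, as set by launchers such as torchrun; an illustrative launch
+# (script name and flags assumed from the training script) would be:
+#
+#     torchrun --nproc_per_node=4 train.py --dist-url env://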
+
+def average_checkpoints(inputs):
+    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
+    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+
+    Args:
+      inputs (List[str]): An iterable of string paths of checkpoints to load from.
+    Returns:
+      A dict of string keys mapping to various values. The 'model' key
+      from the returned dict should correspond to an OrderedDict mapping
+      string parameter names to torch Tensors.
+    """
+    params_dict = OrderedDict()
+    params_keys = None
+    new_state = None
+    num_models = len(inputs)
+    for fpath in inputs:
+        with open(fpath, "rb") as f:
+            state = torch.load(
+                f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
+            )
+        # Copies over the settings from the first checkpoint
+        if new_state is None:
+            new_state = state
+        model_params = state["model"]
+        model_params_keys = list(model_params.keys())
+        if params_keys is None:
+            params_keys = model_params_keys
+        elif params_keys != model_params_keys:
+            raise KeyError(
+                f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
+            )
+        for k in params_keys:
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            if k not in params_dict:
+                params_dict[k] = p.clone()
+                # NOTE: clone() is needed in case p is a shared parameter
+            else:
+                params_dict[k] += p
+    averaged_params = OrderedDict()
+    for k, v in params_dict.items():
+        averaged_params[k] = v
+        if averaged_params[k].is_floating_point():
+            averaged_params[k].div_(num_models)
+        else:
+            averaged_params[k] //= num_models
+    new_state["model"] = averaged_params
+    return new_state
+
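+# Illustrative use of average_checkpoints (paths are placeholders):
+#
+#     averaged = average_checkpoints(["epoch_88.pth", "epoch_89.pth"])
+#     torch.save(averaged, "averaged.pth")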
+
+def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
+    """
+    This method can be used to prepare weights files for new models. It receives as
+    input a model architecture and a checkpoint from the training script and produces
+    a file with the weights ready for release.
+
+    Examples:
+        from torchvision import models as M
+
+        # Classification
+        model = M.mobilenet_v3_large(weights=None)
+        print(store_model_weights(model, './class.pth'))
+
+        # Quantized Classification
+        model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
+        model.fuse_model(is_qat=True)
+        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
+        print(store_model_weights(model, './qat.pth'))
+
+        # Object Detection
+        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
+        print(store_model_weights(model, './obj.pth'))
+
+        # Segmentation
+        model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
+        print(store_model_weights(model, './segm.pth', strict=False))
+
+    Args:
+        model (torch.nn.Module): The model on which the weights will be loaded for validation purposes.
+        checkpoint_path (str): The path of the checkpoint we will load.
+        checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
+            Default: "model".
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+    Returns:
+        output_path (str): The location where the weights are saved.
+    """
+    # Store the new model next to the checkpoint_path
+    checkpoint_path = os.path.abspath(checkpoint_path)
+    output_dir = os.path.dirname(checkpoint_path)
+
+    # Deep copy to avoid side effects on the model object.
+    model = copy.deepcopy(model)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    # Load the weights to the model to validate that everything works
+    # and remove unnecessary weights (such as auxiliaries, etc.)
+    if checkpoint_key == "model_ema":
+        del checkpoint[checkpoint_key]["n_averaged"]
+        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
+    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
+
+    tmp_path = os.path.join(output_dir, str(model.__hash__()))
+    torch.save(model.state_dict(), tmp_path)
+
+    sha256_hash = hashlib.sha256()
+    with open(tmp_path, "rb") as f:
+        # Read the file and update the hash in 4 KiB blocks
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+        hh = sha256_hash.hexdigest()
+
+    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
+    os.replace(tmp_path, output_path)
+
+    return output_path
+
+
+def reduce_across_processes(val):
+    if not is_dist_avail_and_initialized():
+        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
+        return torch.tensor(val)
+
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t
+
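+# Example: aggregating per-process counters; in single-process runs this is
+# a no-op apart from the tensor conversion (names are illustrative):
+#
+#     correct = reduce_across_processes(local_correct)
+#     total = reduce_across_processes(local_total)
+#     acc_global = correct.item() / total.item()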
+
+def set_weight_decay(
+    model: torch.nn.Module,
+    weight_decay: float,
+    norm_weight_decay: Optional[float] = None,
+    norm_classes: Optional[List[type]] = None,
+    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+):
+    if not norm_classes:
+        norm_classes = [
+            torch.nn.modules.batchnorm._BatchNorm,
+            torch.nn.LayerNorm,
+            torch.nn.GroupNorm,
+            torch.nn.modules.instancenorm._InstanceNorm,
+            torch.nn.LocalResponseNorm,
+        ]
+    norm_classes = tuple(norm_classes)
+
+    params = {
+        "other": [],
+        "norm": [],
+    }
+    params_weight_decay = {
+        "other": weight_decay,
+        "norm": norm_weight_decay,
+    }
+    custom_keys = []
+    if custom_keys_weight_decay is not None:
+        for key, custom_wd in custom_keys_weight_decay:
+            params[key] = []
+            params_weight_decay[key] = custom_wd
+            custom_keys.append(key)
+
+    def _add_params(module, prefix=""):
+        for name, p in module.named_parameters(recurse=False):
+            if not p.requires_grad:
+                continue
+            is_custom_key = False
+            for key in custom_keys:
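+                # A key containing a dot is matched against the fully qualified
+                # parameter name; a bare key matches the local name only.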
+                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                if key == target_name:
+                    params[key].append(p)
+                    is_custom_key = True
+                    break
+            if not is_custom_key:
+                if norm_weight_decay is not None and isinstance(module, norm_classes):
+                    params["norm"].append(p)
+                else:
+                    params["other"].append(p)
+
+        for child_name, child_module in module.named_children():
+            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+            _add_params(child_module, prefix=child_prefix)
+
+    _add_params(model)
+
+    param_groups = []
+    for key in params:
+        if len(params[key]) > 0:
+            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+    return param_groups
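+
+
+# A hedged usage sketch for set_weight_decay: norm layers get zero weight
+# decay and a custom key exempts a (hypothetical) `class_token` parameter:
+#
+#     param_groups = set_weight_decay(
+#         model,
+#         weight_decay=1e-4,
+#         norm_weight_decay=0.0,
+#         custom_keys_weight_decay=[("class_token", 0.0)],
+#     )
+#     optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)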