train_quantization.py

import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("The post-training quantization example should not be run in distributed mode")

    # Set the backend engine to ensure that the quantized model runs on the correct kernels
    if args.qbackend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
    torch.backends.quantized.engine = args.qbackend
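    # Note: "fbgemm" targets x86 server CPUs while "qnnpack" targets ARM/mobile CPUs.
    # The engine set here should match the qconfig chosen below, so that the converted
    # int8 model dispatches to kernels that exist on the deployment target.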

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    prefix = "quantized_"
    model_name = args.model
    if not model_name.startswith(prefix):
        model_name = prefix + model_name
    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
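    # quantize=args.test_only above: only --test-only runs want get_model to return an
    # already-converted int8 model; QAT and PTQ need the fp32 model so that fusion and
    # observer insertion below can still be applied.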
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
        torch.ao.quantization.prepare_qat(model, inplace=True)
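        # Eager-mode QAT flow: fuse_model folds conv/bn/relu groups into single
        # modules, the qconfig selects backend-appropriate observer and fake-quant
        # settings, and prepare_qat swaps modules for QAT variants that simulate
        # int8 quantization during training.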

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
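        # Eager-mode PTQ flow: prepare() inserts observers, the calibration pass
        # above records activation ranges, and convert() replaces the observed
        # modules with quantized int8 equivalents using the recorded ranges.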
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
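    # During QAT, observers track weight/activation ranges while fake-quant modules
    # round-trip values through simulated int8 in the forward pass, so the network
    # learns weights that are robust to quantization error.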
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subsequent epochs, epoch =", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subsequent epochs, epoch =", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
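            # Freezing observers and BN statistics part-way through training keeps
            # the quantization parameters and normalization stats stable, so the
            # remaining epochs fine-tune weights against fixed scales.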
            print("Evaluate QAT model")
            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
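            # The converted int8 model only has CPU kernels (fbgemm/qnnpack), hence
            # the deep copy is moved to CPU before convert() and evaluated there,
            # while the fake-quant training model stays on the original device.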
        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu; default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per GPU; the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for observer calibration",
    )
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default: None)")
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

    return parser
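

# Illustrative invocations (paths and epoch counts are placeholders, not tuned recipes):
#   QAT fine-tuning starting from a pre-trained fp32 model:
#     python train_quantization.py --data-path /path/to/imagenet --model mobilenet_v2 \
#         --qbackend qnnpack --epochs 10 --output-dir ./qat_out
#   One-shot post-training quantization with calibration (no training loop):
#     python train_quantization.py --data-path /path/to/imagenet --model mobilenet_v2 \
#         --post-training-quantize --qbackend fbgemm --device cpu --output-dir ./ptq_out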


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    if args.backend in ("fbgemm", "qnnpack"):
        raise ValueError(
            "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
            "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
        )
    main(args)