# utils.py

import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist

class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
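

# Illustrative usage sketch (not part of the original file): shows how the
# sliding window and the global average diverge. `_demo_smoothed_value` is a
# hypothetical helper added purely as an example.
def _demo_smoothed_value():
    meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
    for v in [1.0, 2.0, 3.0, 4.0]:
        meter.update(v)
    # The deque holds only the last 3 values [2.0, 3.0, 4.0], so median is 3.0;
    # global_avg covers all 4 updates: (1 + 2 + 3 + 4) / 4 = 2.5.
    assert meter.median == 3.0
    assert meter.global_avg == 2.5
    print(str(meter))  # "3.00 (2.50)"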


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
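

# Illustrative usage sketch (not part of the original file): the typical
# training-loop pattern for MetricLogger. The loss values are fabricated
# for the example; in real code they come from the criterion.
def _demo_metric_logger():
    logger = MetricLogger(delimiter="  ")
    logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    for step, _ in enumerate(logger.log_every(range(10), print_freq=5, header="Epoch: [0]")):
        fake_loss = 1.0 / (step + 1)  # stand-in for criterion(output, target)
        logger.update(loss=fake_loss, lr=0.1)
    # After the loop, per-meter statistics are available via attribute access.
    print(logger.loss.global_avg)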


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)
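

# Illustrative usage sketch (not part of the original file): maintaining an
# EMA copy of a model during training. The decay value is a common choice,
# not something mandated by this module.
def _demo_ema():
    model = torch.nn.Linear(4, 2)
    model_ema = ExponentialMovingAverage(model, decay=0.999, device="cpu")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        loss = model(torch.randn(8, 4)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_ema.update_parameters(model)  # fold the new weights into the average
    # Evaluate with the averaged weights via model_ema.module.
    print(model_ema.module.weight)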


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.inference_mode():
        maxk = max(topk)
        batch_size = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res
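

# Illustrative worked example (not part of the original file): two samples,
# three classes. Sample 0 is ranked correctly at top-1; sample 1's true class
# only appears in the top-2, so top-1 accuracy is 50% and top-2 is 100%.
def _demo_accuracy():
    output = torch.tensor([[0.1, 0.7, 0.2],   # predicts class 1
                           [0.6, 0.3, 0.1]])  # predicts class 0, runner-up class 1
    target = torch.tensor([1, 1])
    acc1, acc2 = accuracy(output, target, topk=(1, 2))
    assert acc1.item() == 50.0 and acc2.item() == 100.0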


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
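

# Illustrative usage sketch (not part of the original file): the function
# expects an argparse-style namespace carrying at least `dist_url`; the
# remaining fields are filled in from the environment (torchrun / SLURM).
def _demo_init_distributed_mode():
    import argparse

    args = argparse.Namespace(dist_url="env://")
    # Under `torchrun --nproc_per_node=N train.py`, RANK/WORLD_SIZE/LOCAL_RANK
    # are set and distributed mode is initialized; otherwise this prints
    # "Not using distributed mode" and sets args.distributed = False.
    init_distributed_mode(args)
    print(args.distributed)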


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16

    Args:
        inputs (List[str]): An iterable of string paths of checkpoints to load from.
    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.
    """
    params_dict = OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state["model"]
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                f"For checkpoint {fpath}, expected list of params: {params_keys}, but found: {model_params_keys}"
            )
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case p is a shared parameter.
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p
    averaged_params = OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state
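

# Illustrative usage sketch (not part of the original file): averaging the
# last few checkpoints of a run. The paths are hypothetical; each file is
# expected to be a dict with a 'model' state_dict, as saved by the trainer.
def _demo_average_checkpoints():
    inputs = ["checkpoints/model_28.pth", "checkpoints/model_29.pth", "checkpoints/model_30.pth"]
    averaged = average_checkpoints(inputs)
    save_on_master(averaged, "checkpoints/model_averaged.pth")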


def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
    """
    This method can be used to prepare weights files for new models. It receives as
    input a model architecture and a checkpoint from the training script and produces
    a file with the weights ready for release.

    Examples:
        from torchvision import models as M

        # Classification
        model = M.mobilenet_v3_large(weights=None)
        print(store_model_weights(model, './class.pth'))

        # Quantized Classification
        model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
        print(store_model_weights(model, './qat.pth'))

        # Object Detection
        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
        print(store_model_weights(model, './obj.pth'))

        # Segmentation
        model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
        print(store_model_weights(model, './segm.pth', strict=False))

    Args:
        model (torch.nn.Module): The model on which the weights will be loaded for validation purposes.
        checkpoint_path (str): The path of the checkpoint we will load.
        checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
            Default: "model".
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        output_path (str): The location where the weights are saved.
    """
    # Store the new model next to the checkpoint_path
    checkpoint_path = os.path.abspath(checkpoint_path)
    output_dir = os.path.dirname(checkpoint_path)

    # Deep copy to avoid side effects on the model object.
    model = copy.deepcopy(model)
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Load the weights to the model to validate that everything works
    # and remove unnecessary weights (such as auxiliaries, etc.)
    if checkpoint_key == "model_ema":
        del checkpoint[checkpoint_key]["n_averaged"]
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)

    tmp_path = os.path.join(output_dir, str(model.__hash__()))
    torch.save(model.state_dict(), tmp_path)

    sha256_hash = hashlib.sha256()
    with open(tmp_path, "rb") as f:
        # Read and update hash string value in blocks of 4K
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
        hh = sha256_hash.hexdigest()

    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
    os.replace(tmp_path, output_path)

    return output_path


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay: Optional[float] = None,
    norm_classes: Optional[List[type]] = None,
    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
    if not norm_classes:
        norm_classes = [
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        ]
    norm_classes = tuple(norm_classes)

    params = {
        "other": [],
        "norm": [],
    }
    params_weight_decay = {
        "other": weight_decay,
        "norm": norm_weight_decay,
    }
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, weight_decay in custom_keys_weight_decay:
            params[key] = []
            params_weight_decay[key] = weight_decay
            custom_keys.append(key)

    def _add_params(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            is_custom_key = False
            for key in custom_keys:
                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
                if key == target_name:
                    params[key].append(p)
                    is_custom_key = True
                    break
            if not is_custom_key:
                if norm_weight_decay is not None and isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)

        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
            _add_params(child_module, prefix=child_prefix)

    _add_params(model)

    param_groups = []
    for key in params:
        if len(params[key]) > 0:
            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
    return param_groups
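

# Illustrative usage sketch (not part of the original file): build optimizer
# parameter groups where norm layers get no weight decay and a custom key
# ("bias") gets its own value. The concrete numbers are example choices.
def _demo_set_weight_decay():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )
    param_groups = set_weight_decay(
        model,
        weight_decay=1e-4,
        norm_weight_decay=0.0,
        custom_keys_weight_decay=[("bias", 0.0)],
    )
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
    # Expected groups: conv weight (1e-4), BN weight (0.0), both biases (0.0).
    for g in optimizer.param_groups:
        print(g["weight_decay"], len(g["params"]))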