torch_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import datetime
  6. import math
  7. import os
  8. import platform
  9. import subprocess
  10. import time
  11. import warnings
  12. from contextlib import contextmanager
  13. from copy import deepcopy
  14. from pathlib import Path
  15. import torch
  16. import torch.distributed as dist
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from utils.general import LOGGER
  20. try:
  21. import thop # for FLOPs computation
  22. except ImportError:
  23. thop = None
  24. # Suppress PyTorch warnings
  25. warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
  26. @contextmanager
  27. def torch_distributed_zero_first(local_rank: int):
  28. """
  29. Decorator to make all processes in distributed training wait for each local_master to do something.
  30. """
  31. if local_rank not in [-1, 0]:
  32. dist.barrier(device_ids=[local_rank])
  33. yield
  34. if local_rank == 0:
  35. dist.barrier(device_ids=[0])
  36. def date_modified(path=__file__):
  37. # return human-readable file modification date, i.e. '2021-3-26'
  38. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  39. return f'{t.year}-{t.month}-{t.day}'
  40. def git_describe(path=Path(__file__).parent): # path must be a directory
  41. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  42. s = f'git -C {path} describe --tags --long --always'
  43. try:
  44. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  45. except subprocess.CalledProcessError:
  46. return '' # not a git repository
  47. def device_count():
  48. # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux.
  49. assert platform.system() == 'Linux', 'device_count() function only works on Linux'
  50. try:
  51. cmd = 'nvidia-smi -L | wc -l'
  52. return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
  53. except Exception:
  54. return 0
  55. def select_device(device='', batch_size=0, newline=True):
  56. # device = 'cpu' or '0' or '0,1,2,3'
  57. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  58. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  59. cpu = device == 'cpu'
  60. if cpu:
  61. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  62. elif device: # non-cpu device requested
  63. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  64. assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
  65. f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
  66. cuda = not cpu and torch.cuda.is_available()
  67. if cuda:
  68. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  69. n = len(devices) # device count
  70. if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
  71. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  72. space = ' ' * (len(s) + 1)
  73. for i, d in enumerate(devices):
  74. p = torch.cuda.get_device_properties(i)
  75. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MiB)\n" # bytes to MB
  76. else:
  77. s += 'CPU\n'
  78. if not newline:
  79. s = s.rstrip()
  80. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  81. return torch.device('cuda:0' if cuda else 'cpu')
  82. def time_sync():
  83. # pytorch-accurate time
  84. if torch.cuda.is_available():
  85. torch.cuda.synchronize()
  86. return time.time()
  87. def profile(input, ops, n=10, device=None):
  88. # YOLOv5 speed/memory/FLOPs profiler
  89. #
  90. # Usage:
  91. # input = torch.randn(16, 3, 640, 640)
  92. # m1 = lambda x: x * torch.sigmoid(x)
  93. # m2 = nn.SiLU()
  94. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  95. results = []
  96. device = device or select_device()
  97. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  98. f"{'input':>24s}{'output':>24s}")
  99. for x in input if isinstance(input, list) else [input]:
  100. x = x.to(device)
  101. x.requires_grad = True
  102. for m in ops if isinstance(ops, list) else [ops]:
  103. m = m.to(device) if hasattr(m, 'to') else m # device
  104. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  105. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  106. try:
  107. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  108. except Exception:
  109. flops = 0
  110. try:
  111. for _ in range(n):
  112. t[0] = time_sync()
  113. y = m(x)
  114. t[1] = time_sync()
  115. try:
  116. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  117. t[2] = time_sync()
  118. except Exception: # no backward method
  119. # print(e) # for debug
  120. t[2] = float('nan')
  121. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  122. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  123. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  124. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  125. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  126. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  127. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  128. results.append([p, flops, mem, tf, tb, s_in, s_out])
  129. except Exception as e:
  130. print(e)
  131. results.append(None)
  132. torch.cuda.empty_cache()
  133. return results
  134. def is_parallel(model):
  135. # Returns True if model is of type DP or DDP
  136. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  137. def de_parallel(model):
  138. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  139. return model.module if is_parallel(model) else model
  140. def initialize_weights(model):
  141. for m in model.modules():
  142. t = type(m)
  143. if t is nn.Conv2d:
  144. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  145. elif t is nn.BatchNorm2d:
  146. m.eps = 1e-3
  147. m.momentum = 0.03
  148. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  149. m.inplace = True
  150. def find_modules(model, mclass=nn.Conv2d):
  151. # Finds layer indices matching module class 'mclass'
  152. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  153. def sparsity(model):
  154. # Return global model sparsity
  155. a, b = 0, 0
  156. for p in model.parameters():
  157. a += p.numel()
  158. b += (p == 0).sum()
  159. return b / a
  160. def prune(model, amount=0.3):
  161. # Prune model to requested global sparsity
  162. import torch.nn.utils.prune as prune
  163. print('Pruning model... ', end='')
  164. for name, m in model.named_modules():
  165. if isinstance(m, nn.Conv2d):
  166. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  167. prune.remove(m, 'weight') # make permanent
  168. print(' %.3g global sparsity' % sparsity(model))
  169. def fuse_conv_and_bn(conv, bn):
  170. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  171. fusedconv = nn.Conv2d(conv.in_channels,
  172. conv.out_channels,
  173. kernel_size=conv.kernel_size,
  174. stride=conv.stride,
  175. padding=conv.padding,
  176. groups=conv.groups,
  177. bias=True).requires_grad_(False).to(conv.weight.device)
  178. # prepare filters
  179. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  180. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  181. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  182. # prepare spatial bias
  183. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  184. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  185. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  186. return fusedconv
  187. def model_info(model, verbose=False, img_size=640):
  188. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  189. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  190. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  191. if verbose:
  192. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  193. for i, (name, p) in enumerate(model.named_parameters()):
  194. name = name.replace('module_list.', '')
  195. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  196. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  197. try: # FLOPs
  198. from thop import profile
  199. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  200. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  201. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  202. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  203. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  204. except (ImportError, Exception):
  205. fs = ''
  206. LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  207. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  208. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  209. if ratio == 1.0:
  210. return img
  211. else:
  212. h, w = img.shape[2:]
  213. s = (int(h * ratio), int(w * ratio)) # new size
  214. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  215. if not same_shape: # pad/crop img
  216. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  217. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  218. def copy_attr(a, b, include=(), exclude=()):
  219. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  220. for k, v in b.__dict__.items():
  221. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  222. continue
  223. else:
  224. setattr(a, k, v)
  225. class EarlyStopping:
  226. # YOLOv5 simple early stopper
  227. def __init__(self, patience=30):
  228. self.best_fitness = 0.0 # i.e. mAP
  229. self.best_epoch = 0
  230. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  231. self.possible_stop = False # possible stop may occur next epoch
  232. def __call__(self, epoch, fitness):
  233. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  234. self.best_epoch = epoch
  235. self.best_fitness = fitness
  236. delta = epoch - self.best_epoch # epochs without improvement
  237. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  238. stop = delta >= self.patience # stop training if patience exceeded
  239. if stop:
  240. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  241. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  242. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  243. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  244. return stop
  245. class ModelEMA:
  246. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  247. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  248. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  249. """
  250. def __init__(self, model, decay=0.9999, updates=0):
  251. # Create EMA
  252. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  253. # if next(model.parameters()).device.type != 'cpu':
  254. # self.ema.half() # FP16 EMA
  255. self.updates = updates # number of EMA updates
  256. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  257. for p in self.ema.parameters():
  258. p.requires_grad_(False)
  259. def update(self, model):
  260. # Update EMA parameters
  261. with torch.no_grad():
  262. self.updates += 1
  263. d = self.decay(self.updates)
  264. msd = de_parallel(model).state_dict() # model state_dict
  265. for k, v in self.ema.state_dict().items():
  266. if v.dtype.is_floating_point:
  267. v *= d
  268. v += (1 - d) * msd[k].detach()
  269. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  270. # Update EMA attributes
  271. copy_attr(self.ema, model, include, exclude)