general.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import shutil
  14. import signal
  15. import time
  16. import urllib
  17. from itertools import repeat
  18. from multiprocessing.pool import ThreadPool
  19. from pathlib import Path
  20. from subprocess import check_output
  21. from zipfile import ZipFile
  22. import cv2
  23. import numpy as np
  24. import pandas as pd
  25. import pkg_resources as pkg
  26. import torch
  27. import torchvision
  28. import yaml
  29. from utils.downloads import gsutil_getsize
  30. from utils.metrics import box_iou, fitness
# Settings: module-level constants and global library configuration (applied at import time)
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
# NOTE(review): os.cpu_count() can return None on some platforms, which would raise here — TODO confirm acceptable
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
  43. def is_kaggle():
  44. # Is environment a Kaggle Notebook?
  45. try:
  46. assert os.environ.get('PWD') == '/kaggle/working'
  47. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  48. return True
  49. except AssertionError:
  50. return False
  51. def is_writeable(dir, test=False):
  52. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  53. if test: # method 1
  54. file = Path(dir) / 'tmp.txt'
  55. try:
  56. with open(file, 'w'): # open file with write permissions
  57. pass
  58. file.unlink() # remove file
  59. return True
  60. except OSError:
  61. return False
  62. else: # method 2
  63. return os.access(dir, os.R_OK) # possible issues on Windows
  64. def set_logging(name=None, verbose=VERBOSE):
  65. # Sets level and returns logger
  66. if is_kaggle():
  67. for h in logging.root.handlers:
  68. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  69. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  70. logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
  71. return logging.getLogger(name)
LOGGER = set_logging('yolov5')  # module-wide logger, defined globally (used in train.py, val.py, detect.py, etc.)
  73. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  74. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  75. env = os.getenv(env_var)
  76. if env:
  77. path = Path(env) # use environment variable
  78. else:
  79. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  80. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  81. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  82. path.mkdir(exist_ok=True) # make if required
  83. return path
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir; overridable via the YOLOV5_CONFIG_DIR env var
  85. class Profile(contextlib.ContextDecorator):
  86. # Usage: @Profile() decorator or 'with Profile():' context manager
  87. def __enter__(self):
  88. self.start = time.time()
  89. def __exit__(self, type, value, traceback):
  90. print(f'Profile results: {time.time() - self.start:.5f}s')
  91. class Timeout(contextlib.ContextDecorator):
  92. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  93. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  94. self.seconds = int(seconds)
  95. self.timeout_message = timeout_msg
  96. self.suppress = bool(suppress_timeout_errors)
  97. def _timeout_handler(self, signum, frame):
  98. raise TimeoutError(self.timeout_message)
  99. def __enter__(self):
  100. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  101. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  102. def __exit__(self, exc_type, exc_val, exc_tb):
  103. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  104. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  105. return True
class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # directory to chdir into on __enter__
        self.cwd = Path.cwd().resolve()  # current dir, captured now and restored on __exit__

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)  # always restore, even if the body raised
  115. def try_except(func):
  116. # try-except function. Usage: @try_except decorator
  117. def handler(*args, **kwargs):
  118. try:
  119. func(*args, **kwargs)
  120. except Exception as e:
  121. print(e)
  122. return handler
  123. def methods(instance):
  124. # Get class/instance methods
  125. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  126. def print_args(name, opt):
  127. # Print argparser arguments
  128. LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  129. def init_seeds(seed=0):
  130. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  131. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  132. import torch.backends.cudnn as cudnn
  133. random.seed(seed)
  134. np.random.seed(seed)
  135. torch.manual_seed(seed)
  136. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  137. def intersect_dicts(da, db, exclude=()):
  138. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  139. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  140. def get_latest_run(search_dir='.'):
  141. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  142. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  143. return max(last_list, key=os.path.getctime) if last_list else ''
  144. def is_docker():
  145. # Is environment a Docker container?
  146. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  147. def is_colab():
  148. # Is environment a Google Colab instance?
  149. try:
  150. import google.colab
  151. return True
  152. except ImportError:
  153. return False
  154. def is_pip():
  155. # Is file in a pip package?
  156. return 'site-packages' in Path(__file__).resolve().parts
  157. def is_ascii(s=''):
  158. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  159. s = str(s) # convert list, tuple, None, etc. to str
  160. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  161. def is_chinese(s='人工智能'):
  162. # Is string composed of any Chinese characters?
  163. return True if re.search('[\u4e00-\u9fff]', str(s)) else False
  164. def emojis(str=''):
  165. # Return platform-dependent emoji-safe version of string
  166. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  167. def file_size(path):
  168. # Return file/dir size (MB)
  169. path = Path(path)
  170. if path.is_file():
  171. return path.stat().st_size / 1E6
  172. elif path.is_dir():
  173. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  174. else:
  175. return 0.0
  176. def check_online():
  177. # Check internet connectivity
  178. import socket
  179. try:
  180. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  181. return True
  182. except OSError:
  183. return False
  184. @try_except
  185. @WorkingDirectory(ROOT)
  186. def check_git_status():
  187. # Recommend 'git pull' if code is out of date
  188. msg = ', for updates see https://github.com/ultralytics/yolov5'
  189. s = colorstr('github: ') # string
  190. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  191. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  192. assert check_online(), s + 'skipping check (offline)' + msg
  193. cmd = 'git fetch && git config --get remote.origin.url'
  194. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  195. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  196. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  197. if n > 0:
  198. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  199. else:
  200. s += f'up to date with {url} ✅'
  201. LOGGER.info(emojis(s)) # emoji-safe
def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version; hard=True raises AssertionError if too old
    check_version(platform.python_version(), minimum, name='Python ', hard=True)
  205. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  206. # Check version vs. required version
  207. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  208. result = (current == minimum) if pinned else (current >= minimum) # bool
  209. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  210. if hard:
  211. assert result, s # assert min requirements met
  212. if verbose and not result:
  213. LOGGER.warning(s)
  214. return result
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages).
    # With install=True, missing packages are auto-installed via pip (requires network).
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]
    n = 0  # number of packages updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                LOGGER.info(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    # NOTE(review): shell=True with an interpolated requirement string — r comes from a
                    # local requirements file, but verify it is trusted before widening this entry point
                    LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    LOGGER.warning(f'{prefix} {e}')
            else:
                LOGGER.info(f'{s}. Please install and rerun your command.')
    if n:  # if packages updated
        # 'file' only exists when a requirements file (not a list) was passed — hence the locals() probe
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        LOGGER.info(emojis(s))
  248. def check_img_size(imgsz, s=32, floor=0):
  249. # Verify image size is a multiple of stride s in each dimension
  250. if isinstance(imgsz, int): # integer i.e. img_size=640
  251. new_size = max(make_divisible(imgsz, int(s)), floor)
  252. else: # list i.e. img_size=[640, 480]
  253. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  254. if new_size != imgsz:
  255. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  256. return new_size
def check_imshow():
    # Check if environment supports image displays; returns True/False and warns on failure
    try:
        # headless environments (Docker, Colab) cannot open GUI windows
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        # probe by actually showing a 1x1 image, then tearing the window down
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
  270. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  271. # Check file(s) for acceptable suffix
  272. if file and suffix:
  273. if isinstance(suffix, str):
  274. suffix = [suffix]
  275. for f in file if isinstance(file, (list, tuple)) else [file]:
  276. s = Path(f).suffix.lower() # file suffix
  277. if len(s):
  278. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    # Thin wrapper over check_file() with YAML-specific allowed suffixes
    return check_file(file, suffix)
def check_file(file, suffix=''):
    # Search/download file (if necessary) and return its path as str.
    # Resolution order: existing path (or '') -> http(s) URL download -> search data/, models/, utils/.
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists (empty string is returned unchanged)
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            # guard against empty/failed downloads before handing the path back
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
  305. def check_font(font=FONT):
  306. # Download font to CONFIG_DIR if necessary
  307. font = Path(font)
  308. if not font.exists() and not (CONFIG_DIR / font.name).exists():
  309. url = "https://ultralytics.com/assets/" + font.name
  310. LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
  311. torch.hub.download_url_to_file(url, str(font), progress=False)
def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally; returns the dataset dict.
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional): a *.zip path/URL is fetched and its bundled YAML becomes the dataset spec
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))  # first YAML inside the extracted dir
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional): a str/Path is parsed; a dict is used as-is
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths: make train/val/test entries absolute, rooted at 'path' (or the extract dir)
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found, missing paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # NOTE(review): exec() runs arbitrary code from the dataset YAML's 'download' key —
                    # only load trusted dataset definitions
                    r = exec(s, {'yaml': data})  # return None
                LOGGER.info(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
  360. def url2file(url):
  361. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  362. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  363. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  364. return file
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload.
    # url may be a single str/Path or an iterable of URLs (required when threads > 1).
    def download_one(url, dir):
        # Download 1 file into dir, then optionally unzip and delete the archive
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            LOGGER.info(f'Downloading {url} to {f}...')
            if curl:
                # NOTE(review): URL is interpolated into a shell command — only pass trusted URLs
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()  # wait for all submitted downloads to finish
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
  396. def make_divisible(x, divisor):
  397. # Returns nearest x divisible by divisor
  398. if isinstance(divisor, torch.Tensor):
  399. divisor = int(divisor.max()) # to int
  400. return math.ceil(x / divisor) * divisor
  401. def clean_str(s):
  402. # Cleans a string by replacing special characters with underscore _
  403. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  404. def one_cycle(y1=0.0, y2=1.0, steps=100):
  405. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  406. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  407. def colorstr(*input):
  408. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  409. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  410. colors = {'black': '\033[30m', # basic colors
  411. 'red': '\033[31m',
  412. 'green': '\033[32m',
  413. 'yellow': '\033[33m',
  414. 'blue': '\033[34m',
  415. 'magenta': '\033[35m',
  416. 'cyan': '\033[36m',
  417. 'white': '\033[37m',
  418. 'bright_black': '\033[90m', # bright colors
  419. 'bright_red': '\033[91m',
  420. 'bright_green': '\033[92m',
  421. 'bright_yellow': '\033[93m',
  422. 'bright_blue': '\033[94m',
  423. 'bright_magenta': '\033[95m',
  424. 'bright_cyan': '\033[96m',
  425. 'bright_white': '\033[97m',
  426. 'end': '\033[0m', # misc
  427. 'bold': '\033[1m',
  428. 'underline': '\033[4m'}
  429. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  430. def labels_to_class_weights(labels, nc=80):
  431. # Get class weights (inverse frequency) from training labels
  432. if labels[0] is None: # no labels loaded
  433. return torch.Tensor()
  434. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  435. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  436. weights = np.bincount(classes, minlength=nc) # occurrences per class
  437. # Prepend gridpoint count (for uCE training)
  438. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  439. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  440. weights[weights == 0] = 1 # replace empty bins with 1
  441. weights = 1 / weights # number of targets per class
  442. weights /= weights.sum() # normalize
  443. return torch.from_numpy(weights)
  444. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  445. # Produces image weights based on class_weights and image contents
  446. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  447. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  448. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  449. return image_weights
  450. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  451. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  452. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  453. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  454. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  455. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  456. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  457. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  458. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  459. return x
  460. def xyxy2xywh(x):
  461. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  462. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  463. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  464. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  465. y[:, 2] = x[:, 2] - x[:, 0] # width
  466. y[:, 3] = x[:, 3] - x[:, 1] # height
  467. return y
  468. def xywh2xyxy(x):
  469. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  470. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  471. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  472. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  473. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  474. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  475. return y
  476. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  477. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  478. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  479. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  480. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  481. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  482. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  483. return y
  484. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  485. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  486. if clip:
  487. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  488. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  489. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  490. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  491. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  492. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  493. return y
  494. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  495. # Convert normalized segments into pixel segments, shape (n,2)
  496. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  497. y[:, 0] = w * x[:, 0] + padw # top left x
  498. y[:, 1] = h * x[:, 1] + padh # top left y
  499. return y
  500. def segment2box(segment, width=640, height=640):
  501. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  502. x, y = segment.T # segment xy
  503. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  504. x, y, = x[inside], y[inside]
  505. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  506. def segments2boxes(segments):
  507. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  508. boxes = []
  509. for s in segments:
  510. x, y = s.T # segment xy
  511. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  512. return xyxy2xywh(np.array(boxes)) # cls, xywh
  513. def resample_segments(segments, n=1000):
  514. # Up-sample an (n,2) segment
  515. for i, s in enumerate(segments):
  516. x = np.linspace(0, len(s) - 1, n)
  517. xp = np.arange(len(s))
  518. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  519. return segments
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) in place from img1_shape (letterboxed) to img0_shape (original image).
    # ratio_pad: optional precomputed ((gain, ...), (pad_w, pad_h)) from the letterbox step.
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)  # clamp to image bounds (also in-place)
    return coords  # same object as input (mutated)
  533. def clip_coords(boxes, shape):
  534. # Clip bounding xyxy bounding boxes to image shape (height, width)
  535. if isinstance(boxes, torch.Tensor): # faster individually
  536. boxes[:, 0].clamp_(0, shape[1]) # x1
  537. boxes[:, 1].clamp_(0, shape[0]) # y1
  538. boxes[:, 2].clamp_(0, shape[1]) # x2
  539. boxes[:, 3].clamp_(0, shape[0]) # y2
  540. else: # np.array (faster grouped)
  541. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  542. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: model output tensor of shape (batch, n_boxes, 5 + nc); each row is
            (center_x, center_y, w, h, obj_conf, cls_conf_0, ..., cls_conf_{nc-1})
        conf_thres: confidence threshold in [0, 1]; candidates at or below it are discarded
        iou_thres: IoU threshold in [0, 1] passed to torchvision.ops.nms
        classes: optional iterable of class indices; if given, detections of other classes are dropped
        agnostic: if True, NMS ignores class (boxes of different classes can suppress each other)
        multi_label: if True (and nc > 1), emit one detection per (box, class) pair above conf_thres
        labels: optional per-image apriori labels to concatenate (autolabelling); each entry is
            (n, 5) rows of (cls, x, y, w, h) — presumably already scaled to prediction space (TODO confirm)
        max_det: maximum number of detections kept per image after NMS

    Returns:
        list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (objectness above threshold)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 7680  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections (merge-NMS only)
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints: zero the objectness of boxes with out-of-range width/height
        x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # keep confidence candidates only

        # Cat apriori labels if autolabelling: append each label as a box with conf 1.0
        # and a one-hot class score so it survives thresholding and competes in NMS
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose combined conf exceeds the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes: keep only the max_nms most confident
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: offset each box by class_index * max_wh so boxes of different
        # classes never overlap and a single class-agnostic NMS call handles all classes
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  619. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  620. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  621. x = torch.load(f, map_location=torch.device('cpu'))
  622. if x.get('ema'):
  623. x['model'] = x['ema'] # replace model with ema
  624. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  625. x[k] = None
  626. x['epoch'] = -1
  627. x['model'].half() # to FP16
  628. for p in x['model'].parameters():
  629. p.requires_grad = False
  630. torch.save(x, s or f)
  631. mb = os.path.getsize(s or f) / 1E6 # filesize
  632. LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    """Record one hyperparameter-evolution generation: append the row to evolve.csv,
    rewrite hyp_evolve.yaml with the best generation so far, and print a summary.

    Args:
        results: tuple of 7 metrics (P, R, mAP@0.5, mAP@0.5:0.95, box/obj/cls val losses)
        hyp: dict of hyperparameter name -> value for this generation
        save_dir: Path-like directory holding evolve.csv / hyp_evolve.yaml
        bucket: GCS bucket name used to sync evolve.csv via gsutil; falsy to skip syncing
        prefix: colored log-line prefix
    """
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())  # must align 1:1 with keys
    n = len(keys)

    # Download (optional): pull the remote evolve.csv when it is larger than the local copy
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv (fixed 20-char columns; header only on first write)
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml: commented best/last-generation header, then the best generation's hyps.
    # Fitness is computed from the first 4 metric columns (fitness() defined elsewhere in this file).
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # row index of the best generation
        generations = len(data)
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        # columns after the 7 metrics are the hyperparameters themselves
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' +
                prefix + ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' +
                prefix + ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs: re-classify each detection crop and
    # keep only detections where the classifier agrees with the detector's predicted class.
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    #   x:     list of per-image detection tensors (n, 6) [xyxy, conf, cls]
    #   model: classifier taking (N, 3, 224, 224) RGB float input in [0, 1]
    #   img:   network input batch; img.shape[2:] gives the inference (h, w) for rescaling
    #   im0:   original image (np.ndarray, BGR) or list of originals, one per entry in x
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts: square up each box and enlarge it so the crop has context
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size (in place, via scale_coords)
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes predicted by the detector
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
  696. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  697. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  698. path = Path(path) # os-agnostic
  699. if path.exists() and not exist_ok:
  700. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  701. dirs = glob.glob(f"{path}{sep}*") # similar paths
  702. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  703. i = [int(m.groups()[0]) for m in matches if m] # indices
  704. n = max(i) + 1 if i else 2 # increment number
  705. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  706. if mkdir:
  707. path.mkdir(parents=True, exist_ok=True) # make directory
  708. return path
# Variables
# Terminal width for sizing tqdm progress bars; 0 when running inside Docker
# (is_docker() is defined elsewhere in this file)
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm