plots.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # Plotting utils
  2. import glob
  3. import math
  4. import os
  5. import random
  6. from copy import copy
  7. from pathlib import Path
  8. import cv2
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import pandas as pd
  13. import seaborn as sns
  14. import torch
  15. import yaml
  16. from PIL import Image, ImageDraw, ImageFont
  17. from scipy.signal import butter, filtfilt
  18. from utils.general import xywh2xyxy, xyxy2xywh
  19. from utils.metrics import fitness
  20. # Settings
  21. matplotlib.rc('font', **{'size': 11})
  22. matplotlib.use('Agg') # for writing to files only
  23. def color_list():
  24. # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  25. def hex2rgb(h):
  26. return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  27. return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)
  28. def hist2d(x, y, n=100):
  29. # 2d histogram used in labels.png and evolve.png
  30. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  31. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  32. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  33. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  34. return np.log(hist[xidx, yidx])
  35. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  36. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  37. def butter_lowpass(cutoff, fs, order):
  38. nyq = 0.5 * fs
  39. normal_cutoff = cutoff / nyq
  40. return butter(order, normal_cutoff, btype='low', analog=False)
  41. b, a = butter_lowpass(cutoff, fs, order=order)
  42. return filtfilt(b, a, data) # forward-backward filter
  43. def plot_one_box(x, img, color=None, label=None, line_thickness=3):
  44. # Plots one bounding box on image img
  45. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  46. color = color or [random.randint(0, 255) for _ in range(3)]
  47. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  48. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  49. if label:
  50. tf = max(tl - 1, 1) # font thickness
  51. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  52. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  53. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  54. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  55. def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None):
  56. img = Image.fromarray(img)
  57. draw = ImageDraw.Draw(img)
  58. line_thickness = line_thickness or max(int(min(img.size) / 200), 2)
  59. draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot
  60. if label:
  61. fontsize = max(round(max(img.size) / 40), 12)
  62. font = ImageFont.truetype("Arial.ttf", fontsize)
  63. txt_width, txt_height = font.getsize(label)
  64. draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
  65. draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
  66. return np.asarray(img)
  67. def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
  68. # Compares the two methods for width-height anchor multiplication
  69. # https://github.com/ultralytics/yolov3/issues/168
  70. x = np.arange(-4.0, 4.0, .1)
  71. ya = np.exp(x)
  72. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  73. fig = plt.figure(figsize=(6, 3), tight_layout=True)
  74. plt.plot(x, ya, '.-', label='YOLOv3')
  75. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  76. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  77. plt.xlim(left=-4, right=4)
  78. plt.ylim(bottom=0, top=6)
  79. plt.xlabel('input')
  80. plt.ylabel('output')
  81. plt.grid()
  82. plt.legend()
  83. fig.savefig('comparison.png', dpi=200)
  84. def output_to_target(output):
  85. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  86. targets = []
  87. for i, o in enumerate(output):
  88. for *box, conf, cls in o.cpu().numpy():
  89. targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
  90. return np.array(targets)
  91. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  92. # Plot image grid with labels
  93. if isinstance(images, torch.Tensor):
  94. images = images.cpu().float().numpy()
  95. if isinstance(targets, torch.Tensor):
  96. targets = targets.cpu().numpy()
  97. # un-normalise
  98. if np.max(images[0]) <= 1:
  99. images *= 255
  100. tl = 3 # line thickness
  101. tf = max(tl - 1, 1) # font thickness
  102. bs, _, h, w = images.shape # batch size, _, height, width
  103. bs = min(bs, max_subplots) # limit plot images
  104. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  105. # Check if we should resize
  106. scale_factor = max_size / max(h, w)
  107. if scale_factor < 1:
  108. h = math.ceil(scale_factor * h)
  109. w = math.ceil(scale_factor * w)
  110. colors = color_list() # list of colors
  111. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  112. for i, img in enumerate(images):
  113. if i == max_subplots: # if last batch has fewer images than we expect
  114. break
  115. block_x = int(w * (i // ns))
  116. block_y = int(h * (i % ns))
  117. img = img.transpose(1, 2, 0)
  118. if scale_factor < 1:
  119. img = cv2.resize(img, (w, h))
  120. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  121. if len(targets) > 0:
  122. image_targets = targets[targets[:, 0] == i]
  123. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  124. classes = image_targets[:, 1].astype('int')
  125. labels = image_targets.shape[1] == 6 # labels if no conf column
  126. conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
  127. if boxes.shape[1]:
  128. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  129. boxes[[0, 2]] *= w # scale to pixels
  130. boxes[[1, 3]] *= h
  131. elif scale_factor < 1: # absolute coords need scale if image scales
  132. boxes *= scale_factor
  133. boxes[[0, 2]] += block_x
  134. boxes[[1, 3]] += block_y
  135. for j, box in enumerate(boxes.T):
  136. cls = int(classes[j])
  137. color = colors[cls % len(colors)]
  138. cls = names[cls] if names else cls
  139. if labels or conf[j] > 0.25: # 0.25 conf thresh
  140. label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
  141. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  142. # Draw image filename labels
  143. if paths:
  144. label = Path(paths[i]).name[:40] # trim to 40 char
  145. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  146. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  147. lineType=cv2.LINE_AA)
  148. # Image border
  149. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  150. if fname:
  151. r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
  152. mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
  153. # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
  154. Image.fromarray(mosaic).save(fname) # PIL save
  155. return mosaic
  156. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  157. # Plot LR simulating training for full epochs
  158. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  159. y = []
  160. for _ in range(epochs):
  161. scheduler.step()
  162. y.append(optimizer.param_groups[0]['lr'])
  163. plt.plot(y, '.-', label='LR')
  164. plt.xlabel('epoch')
  165. plt.ylabel('LR')
  166. plt.grid()
  167. plt.xlim(0, epochs)
  168. plt.ylim(0)
  169. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  170. plt.close()
  171. def plot_test_txt(): # from utils.plots import *; plot_test()
  172. # Plot test.txt histograms
  173. x = np.loadtxt('test.txt', dtype=np.float32)
  174. box = xyxy2xywh(x[:, :4])
  175. cx, cy = box[:, 0], box[:, 1]
  176. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  177. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  178. ax.set_aspect('equal')
  179. plt.savefig('hist2d.png', dpi=300)
  180. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  181. ax[0].hist(cx, bins=600)
  182. ax[1].hist(cy, bins=600)
  183. plt.savefig('hist1d.png', dpi=200)
  184. def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
  185. # Plot targets.txt histograms
  186. x = np.loadtxt('targets.txt', dtype=np.float32).T
  187. s = ['x targets', 'y targets', 'width targets', 'height targets']
  188. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  189. ax = ax.ravel()
  190. for i in range(4):
  191. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  192. ax[i].legend()
  193. ax[i].set_title(s[i])
  194. plt.savefig('targets.jpg', dpi=200)
  195. def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
  196. # Plot study.txt generated by test.py
  197. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  198. # ax = ax.ravel()
  199. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  200. # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
  201. for f in sorted(Path(path).glob('study*.txt')):
  202. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  203. x = np.arange(y.shape[1]) if x is None else np.array(x)
  204. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  205. # for i in range(7):
  206. # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  207. # ax[i].set_title(s[i])
  208. j = y[3].argmax() + 1
  209. ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
  210. label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  211. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  212. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  213. ax2.grid(alpha=0.2)
  214. ax2.set_yticks(np.arange(20, 60, 5))
  215. ax2.set_xlim(0, 57)
  216. ax2.set_ylim(30, 55)
  217. ax2.set_xlabel('GPU Speed (ms/img)')
  218. ax2.set_ylabel('COCO AP val')
  219. ax2.legend(loc='lower right')
  220. plt.savefig(str(Path(path).name) + '.png', dpi=300)
  221. def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
  222. # plot dataset labels
  223. print('Plotting labels... ')
  224. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  225. nc = int(c.max() + 1) # number of classes
  226. colors = color_list()
  227. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  228. # seaborn correlogram
  229. sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  230. plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
  231. plt.close()
  232. # matplotlib labels
  233. matplotlib.use('svg') # faster
  234. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  235. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  236. ax[0].set_ylabel('instances')
  237. if 0 < len(names) < 30:
  238. ax[0].set_xticks(range(len(names)))
  239. ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  240. else:
  241. ax[0].set_xlabel('classes')
  242. sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  243. sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  244. # rectangles
  245. labels[:, 1:3] = 0.5 # center
  246. labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  247. img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  248. for cls, *box in labels[:1000]:
  249. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
  250. ax[1].imshow(img)
  251. ax[1].axis('off')
  252. for a in [0, 1, 2, 3]:
  253. for s in ['top', 'right', 'left', 'bottom']:
  254. ax[a].spines[s].set_visible(False)
  255. plt.savefig(save_dir / 'labels.jpg', dpi=200)
  256. matplotlib.use('Agg')
  257. plt.close()
  258. # loggers
  259. for k, v in loggers.items() or {}:
  260. if k == 'wandb' and v:
  261. v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
  262. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
  263. # Plot hyperparameter evolution results in evolve.txt
  264. with open(yaml_file) as f:
  265. hyp = yaml.load(f, Loader=yaml.SafeLoader)
  266. x = np.loadtxt('evolve.txt', ndmin=2)
  267. f = fitness(x)
  268. # weights = (f - f.min()) ** 2 # for weighted results
  269. plt.figure(figsize=(10, 12), tight_layout=True)
  270. matplotlib.rc('font', **{'size': 8})
  271. for i, (k, v) in enumerate(hyp.items()):
  272. y = x[:, i + 7]
  273. # mu = (y * weights).sum() / weights.sum() # best weighted result
  274. mu = y[f.argmax()] # best single result
  275. plt.subplot(6, 5, i + 1)
  276. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  277. plt.plot(mu, f.max(), 'k+', markersize=15)
  278. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  279. if i % 5 != 0:
  280. plt.yticks([])
  281. print('%15s: %.3g' % (k, mu))
  282. plt.savefig('evolve.png', dpi=200)
  283. print('\nPlot saved as evolve.png')
  284. def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
  285. # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
  286. ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
  287. s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
  288. files = list(Path(save_dir).glob('frames*.txt'))
  289. for fi, f in enumerate(files):
  290. try:
  291. results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
  292. n = results.shape[1] # number of rows
  293. x = np.arange(start, min(stop, n) if stop else n)
  294. results = results[:, x]
  295. t = (results[0] - results[0].min()) # set t0=0s
  296. results[0] = x
  297. for i, a in enumerate(ax):
  298. if i < len(results):
  299. label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
  300. a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
  301. a.set_title(s[i])
  302. a.set_xlabel('time (s)')
  303. # if fi == len(files) - 1:
  304. # a.set_ylim(bottom=0)
  305. for side in ['top', 'right']:
  306. a.spines[side].set_visible(False)
  307. else:
  308. a.remove()
  309. except Exception as e:
  310. print('Warning: Plotting error for %s; %s' % (f, e))
  311. ax[1].legend()
  312. plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
  313. def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
  314. # Plot training 'results*.txt', overlaying train and val losses
  315. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  316. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  317. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  318. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  319. n = results.shape[1] # number of rows
  320. x = range(start, min(stop, n) if stop else n)
  321. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  322. ax = ax.ravel()
  323. for i in range(5):
  324. for j in [i, i + 5]:
  325. y = results[j, x]
  326. ax[i].plot(x, y, marker='.', label=s[j])
  327. # y_smooth = butter_lowpass_filtfilt(y)
  328. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  329. ax[i].set_title(t[i])
  330. ax[i].legend()
  331. ax[i].set_ylabel(f) if i == 0 else None # add filename
  332. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  333. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  334. # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
  335. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  336. ax = ax.ravel()
  337. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  338. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  339. if bucket:
  340. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  341. files = ['results%g.txt' % x for x in id]
  342. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  343. os.system(c)
  344. else:
  345. files = list(Path(save_dir).glob('results*.txt'))
  346. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  347. for fi, f in enumerate(files):
  348. try:
  349. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  350. n = results.shape[1] # number of rows
  351. x = range(start, min(stop, n) if stop else n)
  352. for i in range(10):
  353. y = results[i, x]
  354. if i in [0, 1, 2, 5, 6, 7]:
  355. y[y == 0] = np.nan # don't show zero loss values
  356. # y /= y[0] # normalize
  357. label = labels[fi] if len(labels) else f.stem
  358. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  359. ax[i].set_title(s[i])
  360. # if i in [5, 6, 7]: # share train and val loss y axes
  361. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  362. except Exception as e:
  363. print('Warning: Plotting error for %s; %s' % (f, e))
  364. ax[1].legend()
  365. fig.savefig(Path(save_dir) / 'results.png', dpi=200)