datasets.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064
  1. # Dataset utils and dataloaders
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import random
  7. import shutil
  8. import time
  9. from itertools import repeat
  10. from multiprocessing.pool import ThreadPool
  11. from pathlib import Path
  12. from threading import Thread
  13. import cv2
  14. import numpy as np
  15. import torch
  16. import torch.nn.functional as F
  17. from PIL import Image, ExifTags
  18. from torch.utils.data import Dataset
  19. from tqdm import tqdm
  20. from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
  21. resample_segments, clean_str
  22. from utils.torch_utils import torch_distributed_zero_first
  23. # Parameters
  24. help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  25. img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
  26. vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
  27. logger = logging.getLogger(__name__)
  28. # Get orientation exif tag
  29. for orientation in ExifTags.TAGS.keys():
  30. if ExifTags.TAGS[orientation] == 'Orientation':
  31. break
  32. def get_hash(files):
  33. # Returns a single hash value of a list of files
  34. return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
  35. def exif_size(img):
  36. # Returns exif-corrected PIL size
  37. s = img.size # (width, height)
  38. try:
  39. rotation = dict(img._getexif().items())[orientation]
  40. if rotation == 6: # rotation 270
  41. s = (s[1], s[0])
  42. elif rotation == 8: # rotation 90
  43. s = (s[1], s[0])
  44. except:
  45. pass
  46. return s
  47. def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
  48. rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
  49. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  50. with torch_distributed_zero_first(rank):
  51. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  52. augment=augment, # augment images
  53. hyp=hyp, # augmentation hyperparameters
  54. rect=rect, # rectangular training
  55. cache_images=cache,
  56. single_cls=opt.single_cls,
  57. stride=int(stride),
  58. pad=pad,
  59. image_weights=image_weights,
  60. prefix=prefix)
  61. batch_size = min(batch_size, len(dataset))
  62. nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
  63. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  64. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  65. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  66. dataloader = loader(dataset,
  67. batch_size=batch_size,
  68. num_workers=nw,
  69. sampler=sampler,
  70. pin_memory=True,
  71. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  72. return dataloader, dataset
  73. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  74. """ Dataloader that reuses workers
  75. Uses same syntax as vanilla DataLoader
  76. """
  77. def __init__(self, *args, **kwargs):
  78. super().__init__(*args, **kwargs)
  79. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  80. self.iterator = super().__iter__()
  81. def __len__(self):
  82. return len(self.batch_sampler.sampler)
  83. def __iter__(self):
  84. for i in range(len(self)):
  85. yield next(self.iterator)
  86. class _RepeatSampler(object):
  87. """ Sampler that repeats forever
  88. Args:
  89. sampler (Sampler)
  90. """
  91. def __init__(self, sampler):
  92. self.sampler = sampler
  93. def __iter__(self):
  94. while True:
  95. yield from iter(self.sampler)
  96. class LoadImages: # for inference
  97. def __init__(self, path, img_size=640, stride=32):
  98. p = str(Path(path).absolute()) # os-agnostic absolute path
  99. if '*' in p:
  100. files = sorted(glob.glob(p, recursive=True)) # glob
  101. elif os.path.isdir(p):
  102. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  103. elif os.path.isfile(p):
  104. files = [p] # files
  105. else:
  106. raise Exception(f'ERROR: {p} does not exist')
  107. images = [x for x in files if x.split('.')[-1].lower() in img_formats]
  108. videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
  109. ni, nv = len(images), len(videos)
  110. self.img_size = img_size
  111. self.stride = stride
  112. self.files = images + videos
  113. self.nf = ni + nv # number of files
  114. self.video_flag = [False] * ni + [True] * nv
  115. self.mode = 'image'
  116. if any(videos):
  117. self.new_video(videos[0]) # new video
  118. else:
  119. self.cap = None
  120. assert self.nf > 0, f'No images or videos found in {p}. ' \
  121. f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
  122. def __iter__(self):
  123. self.count = 0
  124. return self
  125. def __next__(self):
  126. if self.count == self.nf:
  127. raise StopIteration
  128. path = self.files[self.count]
  129. if self.video_flag[self.count]:
  130. # Read video
  131. self.mode = 'video'
  132. ret_val, img0 = self.cap.read()
  133. if not ret_val:
  134. self.count += 1
  135. self.cap.release()
  136. if self.count == self.nf: # last video
  137. raise StopIteration
  138. else:
  139. path = self.files[self.count]
  140. self.new_video(path)
  141. ret_val, img0 = self.cap.read()
  142. self.frame += 1
  143. print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
  144. else:
  145. # Read image
  146. self.count += 1
  147. img0 = cv2.imread(path) # BGR
  148. assert img0 is not None, 'Image Not Found ' + path
  149. print(f'image {self.count}/{self.nf} {path}: ', end='')
  150. # Padded resize
  151. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  152. # Convert
  153. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  154. img = np.ascontiguousarray(img)
  155. return path, img, img0, self.cap
  156. def new_video(self, path):
  157. self.frame = 0
  158. self.cap = cv2.VideoCapture(path)
  159. self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  160. def __len__(self):
  161. return self.nf # number of files
  162. class LoadWebcam: # for inference
  163. def __init__(self, pipe='0', img_size=640, stride=32):
  164. self.img_size = img_size
  165. self.stride = stride
  166. if pipe.isnumeric():
  167. pipe = eval(pipe) # local camera
  168. # pipe = 'rtsp://192.168.1.64/1' # IP camera
  169. # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login
  170. # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera
  171. self.pipe = pipe
  172. self.cap = cv2.VideoCapture(pipe) # video capture object
  173. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  174. def __iter__(self):
  175. self.count = -1
  176. return self
  177. def __next__(self):
  178. self.count += 1
  179. if cv2.waitKey(1) == ord('q'): # q to quit
  180. self.cap.release()
  181. cv2.destroyAllWindows()
  182. raise StopIteration
  183. # Read frame
  184. if self.pipe == 0: # local camera
  185. ret_val, img0 = self.cap.read()
  186. img0 = cv2.flip(img0, 1) # flip left-right
  187. else: # IP camera
  188. n = 0
  189. while True:
  190. n += 1
  191. self.cap.grab()
  192. if n % 30 == 0: # skip frames
  193. ret_val, img0 = self.cap.retrieve()
  194. if ret_val:
  195. break
  196. # Print
  197. assert ret_val, f'Camera Error {self.pipe}'
  198. img_path = 'webcam.jpg'
  199. print(f'webcam {self.count}: ', end='')
  200. # Padded resize
  201. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  202. # Convert
  203. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  204. img = np.ascontiguousarray(img)
  205. return img_path, img, img0, None
  206. def __len__(self):
  207. return 0
  208. class LoadStreams: # multiple IP or RTSP cameras
  209. def __init__(self, sources='streams.txt', img_size=640, stride=32):
  210. self.mode = 'stream'
  211. self.img_size = img_size
  212. self.stride = stride
  213. if os.path.isfile(sources):
  214. with open(sources, 'r') as f:
  215. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  216. else:
  217. sources = [sources]
  218. n = len(sources)
  219. self.imgs = [None] * n
  220. self.sources = [clean_str(x) for x in sources] # clean source names for later
  221. for i, s in enumerate(sources):
  222. # Start the thread to read frames from the video stream
  223. print(f'{i + 1}/{n}: {s}... ', end='')
  224. url = eval(s) if s.isnumeric() else s
  225. if 'youtube.com/' in url or 'youtu.be/' in url: # if source is YouTube video
  226. check_requirements(('pafy', 'youtube_dl'))
  227. import pafy
  228. url = pafy.new(url).getbest(preftype="mp4").url
  229. cap = cv2.VideoCapture(url)
  230. assert cap.isOpened(), f'Failed to open {s}'
  231. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  232. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  233. self.fps = cap.get(cv2.CAP_PROP_FPS) % 100
  234. _, self.imgs[i] = cap.read() # guarantee first frame
  235. thread = Thread(target=self.update, args=([i, cap]), daemon=True)
  236. print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
  237. thread.start()
  238. print('') # newline
  239. # check for common shapes
  240. s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
  241. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  242. if not self.rect:
  243. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  244. def update(self, index, cap):
  245. # Read next stream frame in a daemon thread
  246. n = 0
  247. while cap.isOpened():
  248. n += 1
  249. # _, self.imgs[index] = cap.read()
  250. cap.grab()
  251. if n == 4: # read every 4th frame
  252. success, im = cap.retrieve()
  253. self.imgs[index] = im if success else self.imgs[index] * 0
  254. n = 0
  255. time.sleep(1 / self.fps) # wait time
  256. def __iter__(self):
  257. self.count = -1
  258. return self
  259. def __next__(self):
  260. self.count += 1
  261. img0 = self.imgs.copy()
  262. if cv2.waitKey(1) == ord('q'): # q to quit
  263. cv2.destroyAllWindows()
  264. raise StopIteration
  265. # Letterbox
  266. img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
  267. # Stack
  268. img = np.stack(img, 0)
  269. # Convert
  270. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  271. img = np.ascontiguousarray(img)
  272. return self.sources, img, img0, None
  273. def __len__(self):
  274. return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
  275. def img2label_paths(img_paths):
  276. # Define label paths as a function of image paths
  277. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  278. return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
  279. class LoadImagesAndLabels(Dataset): # for training/testing
  280. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  281. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  282. self.img_size = img_size
  283. self.augment = augment
  284. self.hyp = hyp
  285. self.image_weights = image_weights
  286. self.rect = False if image_weights else rect
  287. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  288. self.mosaic_border = [-img_size // 2, -img_size // 2]
  289. self.stride = stride
  290. self.path = path
  291. try:
  292. f = [] # image files
  293. for p in path if isinstance(path, list) else [path]:
  294. p = Path(p) # os-agnostic
  295. if p.is_dir(): # dir
  296. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  297. # f = list(p.rglob('**/*.*')) # pathlib
  298. elif p.is_file(): # file
  299. with open(p, 'r') as t:
  300. t = t.read().strip().splitlines()
  301. parent = str(p.parent) + os.sep
  302. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  303. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  304. else:
  305. raise Exception(f'{prefix}{p} does not exist')
  306. self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
  307. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
  308. assert self.img_files, f'{prefix}No images found'
  309. except Exception as e:
  310. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  311. # Check cache
  312. self.label_files = img2label_paths(self.img_files) # labels
  313. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
  314. if cache_path.is_file():
  315. cache, exists = torch.load(cache_path), True # load
  316. if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache: # changed
  317. cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
  318. else:
  319. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  320. # Display cache
  321. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  322. if exists:
  323. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  324. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  325. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
  326. # Read cache
  327. cache.pop('hash') # remove hash
  328. cache.pop('version') # remove version
  329. labels, shapes, self.segments = zip(*cache.values())
  330. self.labels = list(labels)
  331. self.shapes = np.array(shapes, dtype=np.float64)
  332. self.img_files = list(cache.keys()) # update
  333. self.label_files = img2label_paths(cache.keys()) # update
  334. if single_cls:
  335. for x in self.labels:
  336. x[:, 0] = 0
  337. n = len(shapes) # number of images
  338. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  339. nb = bi[-1] + 1 # number of batches
  340. self.batch = bi # batch index of image
  341. self.n = n
  342. self.indices = range(n)
  343. # Rectangular Training
  344. if self.rect:
  345. # Sort by aspect ratio
  346. s = self.shapes # wh
  347. ar = s[:, 1] / s[:, 0] # aspect ratio
  348. irect = ar.argsort()
  349. self.img_files = [self.img_files[i] for i in irect]
  350. self.label_files = [self.label_files[i] for i in irect]
  351. self.labels = [self.labels[i] for i in irect]
  352. self.shapes = s[irect] # wh
  353. ar = ar[irect]
  354. # Set training image shapes
  355. shapes = [[1, 1]] * nb
  356. for i in range(nb):
  357. ari = ar[bi == i]
  358. mini, maxi = ari.min(), ari.max()
  359. if maxi < 1:
  360. shapes[i] = [maxi, 1]
  361. elif mini > 1:
  362. shapes[i] = [1, 1 / mini]
  363. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  364. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  365. self.imgs = [None] * n
  366. if cache_images:
  367. gb = 0 # Gigabytes of cached images
  368. self.img_hw0, self.img_hw = [None] * n, [None] * n
  369. results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
  370. pbar = tqdm(enumerate(results), total=n)
  371. for i, x in pbar:
  372. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
  373. gb += self.imgs[i].nbytes
  374. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
  375. pbar.close()
  376. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  377. # Cache dataset labels, check images and read shapes
  378. x = {} # dict
  379. nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
  380. pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
  381. for i, (im_file, lb_file) in enumerate(pbar):
  382. try:
  383. # verify images
  384. im = Image.open(im_file)
  385. im.verify() # PIL verify
  386. shape = exif_size(im) # image size
  387. segments = [] # instance segments
  388. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  389. assert im.format.lower() in img_formats, f'invalid image format {im.format}'
  390. # verify labels
  391. if os.path.isfile(lb_file):
  392. nf += 1 # label found
  393. with open(lb_file, 'r') as f:
  394. l = [x.split() for x in f.read().strip().splitlines()]
  395. if any([len(x) > 8 for x in l]): # is segment
  396. classes = np.array([x[0] for x in l], dtype=np.float32)
  397. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  398. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  399. l = np.array(l, dtype=np.float32)
  400. if len(l):
  401. assert l.shape[1] == 5, 'labels require 5 columns each'
  402. assert (l >= 0).all(), 'negative labels'
  403. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  404. assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
  405. else:
  406. ne += 1 # label empty
  407. l = np.zeros((0, 5), dtype=np.float32)
  408. else:
  409. nm += 1 # label missing
  410. l = np.zeros((0, 5), dtype=np.float32)
  411. x[im_file] = [l, shape, segments]
  412. except Exception as e:
  413. nc += 1
  414. print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
  415. pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
  416. f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  417. pbar.close()
  418. if nf == 0:
  419. print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
  420. x['hash'] = get_hash(self.label_files + self.img_files)
  421. x['results'] = nf, nm, ne, nc, i + 1
  422. x['version'] = 0.1 # cache version
  423. torch.save(x, path) # save for next time
  424. logging.info(f'{prefix}New cache created: {path}')
  425. return x
  426. def __len__(self):
  427. return len(self.img_files)
  428. # def __iter__(self):
  429. # self.count = -1
  430. # print('ran dataset iter')
  431. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  432. # return self
  433. def __getitem__(self, index):
  434. index = self.indices[index] # linear, shuffled, or image_weights
  435. hyp = self.hyp
  436. mosaic = self.mosaic and random.random() < hyp['mosaic']
  437. if mosaic:
  438. # Load mosaic
  439. img, labels = load_mosaic(self, index)
  440. shapes = None
  441. # MixUp https://arxiv.org/pdf/1710.09412.pdf
  442. if random.random() < hyp['mixup']:
  443. img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
  444. r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
  445. img = (img * r + img2 * (1 - r)).astype(np.uint8)
  446. labels = np.concatenate((labels, labels2), 0)
  447. else:
  448. # Load image
  449. img, (h0, w0), (h, w) = load_image(self, index)
  450. # Letterbox
  451. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  452. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  453. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  454. labels = self.labels[index].copy()
  455. if labels.size: # normalized xywh to pixel xyxy format
  456. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  457. if self.augment:
  458. # Augment imagespace
  459. if not mosaic:
  460. img, labels = random_perspective(img, labels,
  461. degrees=hyp['degrees'],
  462. translate=hyp['translate'],
  463. scale=hyp['scale'],
  464. shear=hyp['shear'],
  465. perspective=hyp['perspective'])
  466. # Augment colorspace
  467. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  468. # Apply cutouts
  469. # if random.random() < 0.9:
  470. # labels = cutout(img, labels)
  471. nL = len(labels) # number of labels
  472. if nL:
  473. labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
  474. labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
  475. labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
  476. if self.augment:
  477. # flip up-down
  478. if random.random() < hyp['flipud']:
  479. img = np.flipud(img)
  480. if nL:
  481. labels[:, 2] = 1 - labels[:, 2]
  482. # flip left-right
  483. if random.random() < hyp['fliplr']:
  484. img = np.fliplr(img)
  485. if nL:
  486. labels[:, 1] = 1 - labels[:, 1]
  487. labels_out = torch.zeros((nL, 6))
  488. if nL:
  489. labels_out[:, 1:] = torch.from_numpy(labels)
  490. # Convert
  491. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  492. img = np.ascontiguousarray(img)
  493. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  494. @staticmethod
  495. def collate_fn(batch):
  496. img, label, path, shapes = zip(*batch) # transposed
  497. for i, l in enumerate(label):
  498. l[:, 0] = i # add target image index for build_targets()
  499. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  500. @staticmethod
  501. def collate_fn4(batch):
  502. img, label, path, shapes = zip(*batch) # transposed
  503. n = len(shapes) // 4
  504. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  505. ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
  506. wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
  507. s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
  508. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  509. i *= 4
  510. if random.random() < 0.5:
  511. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
  512. 0].type(img[i].type())
  513. l = label[i]
  514. else:
  515. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  516. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  517. img4.append(im)
  518. label4.append(l)
  519. for i, l in enumerate(label4):
  520. l[:, 0] = i # add target image index for build_targets()
  521. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  522. # Ancillary functions --------------------------------------------------------------------------------------------------
  523. def load_image(self, index):
  524. # loads 1 image from dataset, returns img, original hw, resized hw
  525. img = self.imgs[index]
  526. if img is None: # not cached
  527. path = self.img_files[index]
  528. img = cv2.imread(path) # BGR
  529. assert img is not None, 'Image Not Found ' + path
  530. h0, w0 = img.shape[:2] # orig hw
  531. r = self.img_size / max(h0, w0) # resize image to img_size
  532. if r != 1: # always resize down, only resize up if training with augmentation
  533. interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
  534. img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
  535. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  536. else:
  537. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  538. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  539. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  540. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  541. dtype = img.dtype # uint8
  542. x = np.arange(0, 256, dtype=np.int16)
  543. lut_hue = ((x * r[0]) % 180).astype(dtype)
  544. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  545. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  546. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
  547. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  548. def hist_equalize(img, clahe=True, bgr=False):
  549. # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
  550. yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
  551. if clahe:
  552. c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  553. yuv[:, :, 0] = c.apply(yuv[:, :, 0])
  554. else:
  555. yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
  556. return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
def load_mosaic(self, index):
    """Build a 4-image mosaic around image ``index`` and augment it with ``random_perspective``.

    Image ``index`` plus three indices drawn with ``random.choices`` from ``self.indices`` are tiled onto a
    ``2*img_size`` square canvas filled with 114, meeting at a random centre ``(xc, yc)``. Each centre
    coordinate is drawn from ``uniform(-b, 2*s + b)`` for the corresponding ``self.mosaic_border`` value
    ``b``. Each tile's labels are converted from normalized xywh to canvas pixel xyxy, then all labels and
    segments are clipped to ``[0, 2*s]``. Finally canvas, labels and segments go through
    ``random_perspective`` with the geometric hyperparameters in ``self.hyp`` and ``border=self.mosaic_border``.

    Args:
        self: the dataset; reads ``img_size``, ``mosaic_border``, ``indices``, ``labels``, ``segments``, ``hyp``.
        index: index of the first (top-left) tile.

    Returns:
        tuple: ``(img4, labels4)`` as returned by ``random_perspective``.
    """
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center y, x
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):  # note: ``index`` is rebound to each tile's dataset index
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        # Each branch computes the destination rectangle on the canvas (suffix a) and the matching source
        # rectangle inside the tile (suffix b), so the tile corner touching (xc, yc) lands on the centre and
        # any part outside the canvas is cropped.
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b  # horizontal shift from tile coordinates to canvas coordinates
        padh = y1a - y1b  # vertical shift from tile coordinates to canvas coordinates

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    # labels4[:, 1:] is a view, so clipping with out=x updates labels4 in place
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """Build a 9-image mosaic around image ``index`` and augment it with ``random_perspective``.

    Image ``index`` is placed at ``(s, s)`` on a ``3*s`` square canvas filled with 114, where
    ``s = self.img_size``. Eight randomly chosen tiles are placed clockwise around it, starting at the top,
    using the centre tile's size ``(h0, w0)`` and the previous tile's size ``(hp, wp)``. A random ``2*s``
    window is then cropped from the canvas at offset ``(xc, yc)``, each drawn from ``uniform(0, s)``.
    Labels (converted from normalized xywh to pixel xyxy) and segments are shifted into that window,
    clipped to ``[0, 2*s]``, and passed with the image through ``random_perspective``.

    Args:
        self: the dataset; reads ``img_size``, ``mosaic_border``, ``indices``, ``labels``, ``segments``, ``hyp``.
        index: index of the centre tile.

    Returns:
        tuple: ``(img9, labels9)`` as returned by ``random_perspective``.
    """
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):  # note: ``index`` is rebound to each tile's dataset index
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        # c is the tile's box on the canvas; every branch makes it exactly w wide and h high
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]  # tile origin on the canvas (may be negative before clamping)
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        # Where c was clamped at 0, the source slice skips the same number of tile pixels, so shapes match
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center y, x
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    # labels9[:, 1:] is a view, so clipping with out=x updates labels9 in place
    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  666. def replicate(img, labels):
  667. # Replicate labels
  668. h, w = img.shape[:2]
  669. boxes = labels[:, 1:].astype(int)
  670. x1, y1, x2, y2 = boxes.T
  671. s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
  672. for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
  673. x1b, y1b, x2b, y2b = boxes[i]
  674. bh, bw = y2b - y1b, x2b - x1b
  675. yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
  676. x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
  677. img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  678. labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
  679. return img, labels
  680. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
  681. # Resize and pad image while meeting stride-multiple constraints
  682. shape = img.shape[:2] # current shape [height, width]
  683. if isinstance(new_shape, int):
  684. new_shape = (new_shape, new_shape)
  685. # Scale ratio (new / old)
  686. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  687. if not scaleup: # only scale down, do not scale up (for better test mAP)
  688. r = min(r, 1.0)
  689. # Compute padding
  690. ratio = r, r # width, height ratios
  691. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  692. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  693. if auto: # minimum rectangle
  694. dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
  695. elif scaleFill: # stretch
  696. dw, dh = 0.0, 0.0
  697. new_unpad = (new_shape[1], new_shape[0])
  698. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  699. dw /= 2 # divide padding into 2 sides
  700. dh /= 2
  701. if shape[::-1] != new_unpad: # resize
  702. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  703. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  704. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  705. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  706. return img, ratio, (dw, dh)
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    """Apply a random perspective/affine warp to *img* and transform its labels to match.

    The transform is M = T @ S @ R @ P @ C: move the image centre to the origin (C), random
    perspective (P), rotation by up to +/-degrees with scale drawn from [1 - scale, 1 + scale] (R),
    shear by up to +/-shear degrees on each axis (S), and translation to a random point within
    +/-translate (as a fraction of the output size) of the output centre (T). Random values are
    drawn in exactly this order: P x, P y, angle, scale, shear x, shear y, T x, T y. Reordering
    changes results under a fixed seed.

    Arguments
        img: image array; the output size is (img height + 2 * border[0], img width + 2 * border[1]),
            so negative border values crop the output
        targets: array of rows [cls, x1, y1, x2, y2] in pixels
        segments: sequence of (k, 2) polygon point arrays, expected one per target; when any is
            non-zero, boxes are recomputed from the warped polygons instead of the warped box corners
        degrees, translate, scale, shear: ranges of the random transform components (see above)
        perspective: range of the perspective terms; when non-zero cv2.warpPerspective is used,
            otherwise cv2.warpAffine
        border: (y, x) border added to each side of the output

    Returns
        (img, targets): the warped image and the targets that pass box_candidates, with updated boxes
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    # skip the warp entirely when there is no border change and M is the identity
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))  # warped boxes, x1y1x2y2
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))  # homogeneous coordinates
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                # NOTE(review): segment2box is defined elsewhere; presumably clips to (width, height)
                # and returns the polygon's bounding box — confirm
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))  # 4 homogeneous corner points per box
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes (axis-aligned hull of the 4 warped corners)
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        # original boxes are multiplied by the random scale s before being compared with the warped ones
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets
  779. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
  780. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  781. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  782. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  783. ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
  784. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
  785. def cutout(image, labels):
  786. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  787. h, w = image.shape[:2]
  788. def bbox_ioa(box1, box2):
  789. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  790. box2 = box2.transpose()
  791. # Get the coordinates of bounding boxes
  792. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  793. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  794. # Intersection area
  795. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  796. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  797. # box2 area
  798. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  799. # Intersection over box2 area
  800. return inter_area / box2_area
  801. # create random masks
  802. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  803. for s in scales:
  804. mask_h = random.randint(1, int(h * s))
  805. mask_w = random.randint(1, int(w * s))
  806. # box
  807. xmin = max(0, random.randint(0, w) - mask_w // 2)
  808. ymin = max(0, random.randint(0, h) - mask_h // 2)
  809. xmax = min(w, xmin + mask_w)
  810. ymax = min(h, ymin + mask_h)
  811. # apply random color mask
  812. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  813. # return unobscured labels
  814. if len(labels) and s > 0.03:
  815. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  816. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  817. labels = labels[ioa < 0.60] # remove >60% obscured labels
  818. return labels
  819. def create_folder(path='./new'):
  820. # Create folder
  821. if os.path.exists(path):
  822. shutil.rmtree(path) # delete output folder
  823. os.makedirs(path) # make new output folder
  824. def flatten_recursive(path='../coco128'):
  825. # Flatten a recursive directory by bringing all files to top level
  826. new_path = Path(path + '_flat')
  827. create_folder(new_path)
  828. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  829. shutil.copyfile(file, new_path / Path(file).name)
  830. def extract_boxes(path='../coco128/'): # from utils.datasets import *; extract_boxes('../coco128')
  831. # Convert detection dataset into classification dataset, with one directory per class
  832. path = Path(path) # images dir
  833. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  834. files = list(path.rglob('*.*'))
  835. n = len(files) # number of files
  836. for im_file in tqdm(files, total=n):
  837. if im_file.suffix[1:] in img_formats:
  838. # image
  839. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  840. h, w = im.shape[:2]
  841. # labels
  842. lb_file = Path(img2label_paths([str(im_file)])[0])
  843. if Path(lb_file).exists():
  844. with open(lb_file, 'r') as f:
  845. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  846. for j, x in enumerate(lb):
  847. c = int(x[0]) # class
  848. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  849. if not f.parent.is_dir():
  850. f.parent.mkdir(parents=True)
  851. b = x[1:] * [w, h, w, h] # box
  852. # b[2:] = b[2:].max() # rectangle to square
  853. b[2:] = b[2:] * 1.2 + 3 # pad
  854. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  855. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  856. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  857. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  858. def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
  859. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  860. Usage: from utils.datasets import *; autosplit('../coco128')
  861. Arguments
  862. path: Path to images directory
  863. weights: Train, val, test weights (list)
  864. annotated_only: Only use images with an annotated txt file
  865. """
  866. path = Path(path) # images dir
  867. files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], []) # image files only
  868. n = len(files) # number of files
  869. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  870. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  871. [(path / x).unlink() for x in txt if (path / x).exists()] # remove existing
  872. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  873. for i, img in tqdm(zip(indices, files), total=n):
  874. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  875. with open(path / txt[i], 'a') as f:
  876. f.write(str(img) + '\n') # add image to txt file