# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Common modules
"""

import json
import math
import platform
import warnings
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import yaml
from PIL import Image
from torch.cuda import amp

from utils.datasets import exif_transpose, letterbox
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
                           make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, time_sync


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
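
# Minimal sketch of the 'same' padding autopad computes (shapes are illustrative, not from the source):
# _x = torch.zeros(1, 8, 32, 32)
# _conv = nn.Conv2d(8, 8, 3, 1, autopad(3))  # autopad(3) == 1
# assert _conv(_x).shape == _x.shape  # a stride-1 conv preserves the 32x32 spatial size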


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))
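
# Sketch usage of Conv (hypothetical shapes): a 3x3 stride-2 Conv halves spatial size via autopad.
# _m = Conv(3, 16, k=3, s=2)
# _y = _m(torch.zeros(1, 3, 64, 64))  # -> shape (1, 16, 32, 32)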


class DWConv(Conv):
    # Depth-wise convolution class
    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class TransformerLayer(nn.Module):
    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x


class TransformerBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).permute(2, 0, 1)
        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
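
# Sketch of the TransformerBlock reshaping (hypothetical shapes): a (b, c2, w, h) feature map
# is flattened to a (w*h, b, c2) sequence-first token stack for nn.MultiheadAttention, then restored.
# _y = TransformerBlock(64, 64, num_heads=4, num_layers=1)(torch.zeros(2, 64, 8, 8))  # -> (2, 64, 8, 8)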


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.SiLU()
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class C3TR(C3):
    # C3 module with TransformerBlock()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = TransformerBlock(c_, c_, 4, n)


class C3SPP(C3):
    # C3 module with SPP()
    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = SPP(c_, c_, k)


class C3Ghost(C3):
    # C3 module with GhostBottleneck()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))


class SPP(nn.Module):
    # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
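
# Sketch of why SPPF(k=5) matches SPP(k=(5, 9, 13)): stride-1 max-pools compose, so two chained
# 5x5 pools give a 9x9 receptive field and three give 13x13 (hypothetical check, not from the source):
# import torch.nn.functional as F
# _x = torch.rand(1, 8, 32, 32)
# _p5 = F.max_pool2d(_x, 5, 1, 2)
# assert torch.equal(F.max_pool2d(_p5, 5, 1, 2), F.max_pool2d(_x, 9, 1, 4))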


class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
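
# Sketch: the Focus slicing is a space-to-depth rearrangement, e.g. x(1, 3, 640, 640) becomes four
# (1, 3, 320, 320) slices concatenated to (1, 12, 320, 320) before the conv (hypothetical shapes):
# _y = Focus(3, 64)(torch.zeros(1, 3, 640, 640))  # -> (1, 64, 320, 320)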


class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)


class GhostBottleneck(nn.Module):
    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super().__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)


class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        b, c, h, w = x.size()  # assert h % s == 0 and w % s == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(b, c, h // s, s, w // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(b, c * s * s, h // s, w // s)  # x(1,256,40,40)


class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        b, c, h, w = x.size()  # assert c % s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(b, s, s, c // s ** 2, h, w)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return x.view(b, c // s ** 2, h * s, w * s)  # x(1,16,160,160)
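
# Sketch: with matching gain, Expand should invert Contract exactly (hypothetical roundtrip check):
# _x = torch.rand(1, 64, 80, 80)
# assert torch.equal(Expand(gain=2)(Contract(gain=2)(_x)), _x)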


class Concat(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)


class DetectMultiBackend(nn.Module):
    # YOLOv5 MultiBackend class for python inference on various backends
    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
        # Usage:
        #   PyTorch:                weights = *.pt
        #   TorchScript:                      *.torchscript
        #   ONNX Runtime:                     *.onnx
        #   ONNX OpenCV DNN:                  *.onnx with --dnn
        #   OpenVINO:                         *.xml
        #   CoreML:                           *.mlmodel
        #   TensorRT:                         *.engine
        #   TensorFlow SavedModel:            *_saved_model
        #   TensorFlow GraphDef:              *.pb
        #   TensorFlow Lite:                  *.tflite
        #   TensorFlow Edge TPU:              *_edgetpu.tflite
        from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import

        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w)  # get backend
        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
        w = attempt_download(w)  # download if not local
        if data:  # data.yaml path (optional)
            with open(data, errors='ignore') as f:
                names = yaml.safe_load(f)['names']  # class names

        if pt:  # PyTorch
            model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files)
            if extra_files['config.txt']:
                d = json.loads(extra_files['config.txt'])  # extra_files dict
                stride, names = int(d['stride']), d['names']
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements(('opencv-python>=4.5.4',))
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            cuda = torch.cuda.is_available()
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            import openvino.inference_engine as ie
            core = ie.IECore()
            if not Path(w).is_file():  # if not *.xml
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                model = runtime.deserialize_cuda_engine(f.read())
            bindings = OrderedDict()
            for index in range(model.num_bindings):
                name = model.get_binding_name(index)
                dtype = trt.nptype(model.get_binding_dtype(index))
                shape = tuple(model.get_binding_shape(index))
                data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            context = model.create_execution_context()
            batch_size = bindings['images'].shape[0]
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            if saved_model:  # SavedModel
                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
                import tensorflow as tf
                keras = False  # assume TF1 saved_model
                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
                import tensorflow as tf

                def wrap_frozen_graph(gd, inputs, outputs):
                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                    ge = x.graph.as_graph_element
                    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

                gd = tf.Graph().as_graph_def()  # graph_def
                gd.ParseFromString(open(w, 'rb').read())
                frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
            elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                    from tflite_runtime.interpreter import Interpreter, load_delegate
                except ImportError:
                    import tensorflow as tf
                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
                if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                    delegate = {'Linux': 'libedgetpu.so.1',
                                'Darwin': 'libedgetpu.1.dylib',
                                'Windows': 'edgetpu.dll'}[platform.system()]
                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
                else:  # Lite
                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                    interpreter = Interpreter(model_path=w)  # load TFLite model
                interpreter.allocate_tensors()  # allocate
                input_details = interpreter.get_input_details()  # inputs
                output_details = interpreter.get_output_details()  # outputs
            elif tfjs:
                raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
        self.__dict__.update(locals())  # assign all variables to self

    def forward(self, im, augment=False, visualize=False, val=False):
        # YOLOv5 MultiBackend inference
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.pt or self.jit:  # PyTorch
            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
            return y if val else y[0]
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW')  # Tensor Description
            request = self.executable_network.requests[0]  # inference request
            request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im))  # name=next(iter(request.input_blobs))
            request.infer()
            y = request.output_blobs['output'].buffer  # name=next(iter(request.output_blobs))
        elif self.engine:  # TensorRT
            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = self.bindings['output'].data
        elif self.coreml:  # CoreML
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            im = Image.fromarray((im[0] * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.ANTIALIAS)
            y = self.model.predict({'image': im})  # coordinates are xywh normalized
            if 'confidence' in y:
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            else:
                k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
                y = y[k]  # output
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            if self.saved_model:  # SavedModel
                y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im)).numpy()
            else:  # Lite or Edge TPU
                input, output = self.input_details[0], self.output_details[0]
                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
                if int8:
                    scale, zero_point = input['quantization']
                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
                self.interpreter.set_tensor(input['index'], im)
                self.interpreter.invoke()
                y = self.interpreter.get_tensor(output['index'])
                if int8:
                    scale, zero_point = output['quantization']
                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
            y[..., :4] *= [w, h, w, h]  # xywh normalized to pixels
        y = torch.tensor(y) if isinstance(y, np.ndarray) else y
        return (y, []) if val else y

    def warmup(self, imgsz=(1, 3, 640, 640), half=False):
        # Warmup model by running inference once
        if self.pt or self.jit or self.onnx or self.engine:  # warmup types
            if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
                im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float)  # input image
                self.forward(im)  # warmup

    @staticmethod
    def model_type(p='path/to/model.pt'):
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        from export import export_formats
        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
        check_suffix(p, suffixes)  # checks
        p = Path(p).name  # eliminate trailing separators
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
        xml |= xml2  # *_openvino_model or *.xml
        tflite &= not edgetpu  # *.tflite
        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
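
# Sketch usage of DetectMultiBackend (hypothetical weights path and device):
# _model = DetectMultiBackend('yolov5s.pt', device=torch.device('cpu'))
# _model.warmup(imgsz=(1, 3, 640, 640))  # no-op on CPU; runs one dummy pass on GPU
# _pred = _model(torch.zeros(1, 3, 640, 640))  # raw predictions, before NMS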


class DetectPrunedMultiBackend(nn.Module):
    # YOLOv5 MultiBackend class for inference with pruned-model weights on various backends
    def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
        # Usage:
        #   PyTorch:                weights = *.pt
        #   TorchScript:                      *.torchscript
        #   ONNX Runtime:                     *.onnx
        #   ONNX OpenCV DNN:                  *.onnx with --dnn
        #   OpenVINO:                         *.xml
        #   CoreML:                           *.mlmodel
        #   TensorRT:                         *.engine
        #   TensorFlow SavedModel:            *_saved_model
        #   TensorFlow GraphDef:              *.pb
        #   TensorFlow Lite:                  *.tflite
        #   TensorFlow Edge TPU:              *_edgetpu.tflite
        from models.experimental import attempt_download, attempt_load_pruned  # scoped to avoid circular import

        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w)  # get backend
        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
        w = attempt_download(w)  # download if not local
        if data:  # data.yaml path (optional)
            with open(data, errors='ignore') as f:
                names = yaml.safe_load(f)['names']  # class names

        if pt:  # PyTorch
            model = attempt_load_pruned(weights if isinstance(weights, list) else w, map_location=device)
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files)
            if extra_files['config.txt']:
                d = json.loads(extra_files['config.txt'])  # extra_files dict
                stride, names = int(d['stride']), d['names']
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements(('opencv-python>=4.5.4',))
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            cuda = torch.cuda.is_available()
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            import openvino.inference_engine as ie
            core = ie.IECore()
            if not Path(w).is_file():  # if not *.xml
                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                model = runtime.deserialize_cuda_engine(f.read())
            bindings = OrderedDict()
            for index in range(model.num_bindings):
                name = model.get_binding_name(index)
                dtype = trt.nptype(model.get_binding_dtype(index))
                shape = tuple(model.get_binding_shape(index))
                data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            context = model.create_execution_context()
            batch_size = bindings['images'].shape[0]
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            if saved_model:  # SavedModel
                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
                import tensorflow as tf
                keras = False  # assume TF1 saved_model
                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
                import tensorflow as tf

                def wrap_frozen_graph(gd, inputs, outputs):
                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                    ge = x.graph.as_graph_element
                    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

                gd = tf.Graph().as_graph_def()  # graph_def
                gd.ParseFromString(open(w, 'rb').read())
                frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
            elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                    from tflite_runtime.interpreter import Interpreter, load_delegate
                except ImportError:
                    import tensorflow as tf
                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
                if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                    delegate = {'Linux': 'libedgetpu.so.1',
                                'Darwin': 'libedgetpu.1.dylib',
                                'Windows': 'edgetpu.dll'}[platform.system()]
                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
                else:  # Lite
                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                    interpreter = Interpreter(model_path=w)  # load TFLite model
                interpreter.allocate_tensors()  # allocate
                input_details = interpreter.get_input_details()  # inputs
                output_details = interpreter.get_output_details()  # outputs
            elif tfjs:
                raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
        self.__dict__.update(locals())  # assign all variables to self

    def forward(self, im, augment=False, visualize=False, val=False):
        # YOLOv5 MultiBackend inference
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.pt or self.jit:  # PyTorch
            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
            return y if val else y[0]
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW')  # Tensor Description
            request = self.executable_network.requests[0]  # inference request
            request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im))  # name=next(iter(request.input_blobs))
            request.infer()
            y = request.output_blobs['output'].buffer  # name=next(iter(request.output_blobs))
        elif self.engine:  # TensorRT
            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = self.bindings['output'].data
        elif self.coreml:  # CoreML
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            im = Image.fromarray((im[0] * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.ANTIALIAS)
            y = self.model.predict({'image': im})  # coordinates are xywh normalized
            if 'confidence' in y:
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            else:
                k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
                y = y[k]  # output
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            if self.saved_model:  # SavedModel
                y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im)).numpy()
            else:  # Lite or Edge TPU
                input, output = self.input_details[0], self.output_details[0]
                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
                if int8:
                    scale, zero_point = input['quantization']
                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
                self.interpreter.set_tensor(input['index'], im)
                self.interpreter.invoke()
                y = self.interpreter.get_tensor(output['index'])
                if int8:
                    scale, zero_point = output['quantization']
                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
            y[..., :4] *= [w, h, w, h]  # xywh normalized to pixels
        y = torch.tensor(y) if isinstance(y, np.ndarray) else y
        return (y, []) if val else y

    def warmup(self, imgsz=(1, 3, 640, 640), half=False):
        # Warmup model by running inference once
        if self.pt or self.jit or self.onnx or self.engine:  # warmup types
            if isinstance(self.device, torch.device) and self.device.type != 'cpu':  # only warmup GPU models
                im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float)  # input image
                self.forward(im)  # warmup

    @staticmethod
    def model_type(p='path/to/model.pt'):
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        from export import export_formats
        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
        check_suffix(p, suffixes)  # checks
        p = Path(p).name  # eliminate trailing separators
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
        xml |= xml2  # *_openvino_model or *.xml
        tflite &= not edgetpu  # *.tflite
        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs


class AutoShape(nn.Module):
    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    agnostic = False  # NMS class-agnostic
    multi_label = False  # NMS multiple labels per box
    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
    max_det = 1000  # maximum number of detections per image
    amp = False  # Automatic Mixed Precision (AMP) inference

    def __init__(self, model):
        super().__init__()
        LOGGER.info('Adding AutoShape... ')
        copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
        self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
        self.pt = not self.dmb or model.pt  # PyTorch model
        self.model = model.eval()

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        if self.pt:
            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
        #   file:        imgs = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:              = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:           = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:              = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:            = np.zeros((640,1280,3))  # HWC
        #   torch:            = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:         = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        t = [time_sync()]
        p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
        autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
        if isinstance(imgs, torch.Tensor):  # torch
            with amp.autocast(enabled=autocast):
                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            f = f'image{i}'  # filename
            if isinstance(im, (str, Path)):  # filename or uri
                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
                im = np.asarray(exif_transpose(im))
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
            files.append(Path(f).with_suffix('.jpg').name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
        shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
        t.append(time_sync())

        with amp.autocast(enabled=autocast):
            # Inference
            y = self.model(x, augment, profile)  # forward
            t.append(time_sync())

            # Post-process
            y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
                                    agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det)  # NMS
            for i in range(n):
                scale_coords(shape1, y[i][:, :4], shape0[i])

            t.append(time_sync())
            return Detections(imgs, y, files, t, self.names, x.shape)
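
# Sketch usage of AutoShape (hypothetical weights and image paths):
# _model = AutoShape(DetectMultiBackend('yolov5s.pt'))
# _results = _model('data/images/zidane.jpg', size=640)  # accepts cv2/np/PIL/torch inputs
# _results.print()  # or _results.save(), _results.pandas().xyxy[0]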


class Detections:
    # YOLOv5 detections class for inference results
    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
        super().__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.times = times  # profiling times
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
        crops = []
        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
            if pred.shape[0]:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render or crop:
                    annotator = Annotator(im, example=str(self.names))
                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        if crop:
                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
                            crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
                                          'im': save_one_box(box, im, file=file, save=save)})
                        else:  # all others
                            annotator.box_label(box, label, color=colors(cls))
                    im = annotator.im
            else:
                s += '(no detections)'

            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
            if pprint:
                LOGGER.info(s.rstrip(', '))
            if show:
                im.show(self.files[i])  # show
            if save:
                f = self.files[i]
                im.save(save_dir / f)  # save
                if i == self.n - 1:
                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
            if render:
                self.imgs[i] = np.asarray(im)
        if crop:
            if save:
                LOGGER.info(f'Saved results to {save_dir}\n')
            return crops

    def print(self):
        self.display(pprint=True)  # print results
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
                    self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='runs/detect/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
        self.display(save=True, save_dir=save_dir)  # save results

    def crop(self, save=True, save_dir='runs/detect/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
        return self.display(crop=True, save=save, save_dir=save_dir)  # crop results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        r = range(self.n)  # iterable
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
        # for d in x:
        #     for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
        #         setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def __len__(self):
        return self.n


class Classify(nn.Module):
    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
        return self.flat(self.conv(z))  # flatten to x(b,c2)
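
# Sketch usage of Classify (hypothetical shapes): adaptive-pool to 1x1, then a 1x1 conv maps channels.
# _head = Classify(512, 10)
# _logits = _head(torch.zeros(2, 512, 20, 20))  # -> shape (2, 10)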