# yolo.py (12 KB)
  1. # YOLOv5 YOLO-specific modules
  2. import argparse
  3. import logging
  4. import sys
  5. from copy import deepcopy
  6. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  7. logger = logging.getLogger(__name__)
  8. from models.common import *
  9. from models.experimental import *
  10. from utils.autoanchor import check_anchor_order
  11. from utils.general import make_divisible, check_file, set_logging
  12. from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
  13. select_device, copy_attr
  14. try:
  15. import thop # for FLOPS computation
  16. except ImportError:
  17. thop = None
  18. class Detect(nn.Module):
  19. stride = None # strides computed during build
  20. export = False # onnx export
  21. def __init__(self, nc=80, anchors=(), ch=()): # detection layer
  22. super(Detect, self).__init__()
  23. self.nc = nc # number of classes
  24. self.no = nc + 5 # number of outputs per anchor
  25. self.nl = len(anchors) # number of detection layers
  26. self.na = len(anchors[0]) // 2 # number of anchors
  27. self.grid = [torch.zeros(1)] * self.nl # init grid
  28. a = torch.tensor(anchors).float().view(self.nl, -1, 2)
  29. self.register_buffer('anchors', a) # shape(nl,na,2)
  30. self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
  31. self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
  32. def forward(self, x):
  33. # x = x.copy() # for profiling
  34. z = [] # inference output
  35. self.training |= self.export
  36. for i in range(self.nl):
  37. x[i] = self.m[i](x[i]) # conv
  38. bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
  39. x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
  40. if not self.training: # inference
  41. if self.grid[i].shape[2:4] != x[i].shape[2:4]:
  42. self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
  43. y = x[i].sigmoid()
  44. y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  45. y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
  46. z.append(y.view(bs, -1, self.no))
  47. return x if self.training else (torch.cat(z, 1), x)
  48. @staticmethod
  49. def _make_grid(nx=20, ny=20):
  50. yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  51. return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  52. class Model(nn.Module):
  53. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
  54. super(Model, self).__init__()
  55. if isinstance(cfg, dict):
  56. self.yaml = cfg # model dict
  57. else: # is *.yaml
  58. import yaml # for torch hub
  59. self.yaml_file = Path(cfg).name
  60. with open(cfg) as f:
  61. self.yaml = yaml.load(f, Loader=yaml.SafeLoader) # model dict
  62. # Define model
  63. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  64. if nc and nc != self.yaml['nc']:
  65. logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  66. self.yaml['nc'] = nc # override yaml value
  67. if anchors:
  68. logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
  69. self.yaml['anchors'] = round(anchors) # override yaml value
  70. self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
  71. self.names = [str(i) for i in range(self.yaml['nc'])] # default names
  72. # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
  73. # Build strides, anchors
  74. m = self.model[-1] # Detect()
  75. if isinstance(m, Detect):
  76. s = 256 # 2x min stride
  77. m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
  78. m.anchors /= m.stride.view(-1, 1, 1)
  79. check_anchor_order(m)
  80. self.stride = m.stride
  81. self._initialize_biases() # only run once
  82. # print('Strides: %s' % m.stride.tolist())
  83. # Init weights, biases
  84. initialize_weights(self)
  85. self.info()
  86. logger.info('')
  87. def forward(self, x, augment=False, profile=False):
  88. if augment:
  89. img_size = x.shape[-2:] # height, width
  90. s = [1, 0.83, 0.67] # scales
  91. f = [None, 3, None] # flips (2-ud, 3-lr)
  92. y = [] # outputs
  93. for si, fi in zip(s, f):
  94. xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
  95. yi = self.forward_once(xi)[0] # forward
  96. # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
  97. yi[..., :4] /= si # de-scale
  98. if fi == 2:
  99. yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
  100. elif fi == 3:
  101. yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
  102. y.append(yi)
  103. return torch.cat(y, 1), None # augmented inference, train
  104. else:
  105. return self.forward_once(x, profile) # single-scale inference, train
  106. def forward_once(self, x, profile=False):
  107. y, dt = [], [] # outputs
  108. for m in self.model:
  109. if m.f != -1: # if not from previous layer
  110. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  111. if profile:
  112. o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
  113. t = time_synchronized()
  114. for _ in range(10):
  115. _ = m(x)
  116. dt.append((time_synchronized() - t) * 100)
  117. print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
  118. x = m(x) # run
  119. y.append(x if m.i in self.save else None) # save output
  120. if profile:
  121. print('%.1fms total' % sum(dt))
  122. return x
  123. def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
  124. # https://arxiv.org/abs/1708.02002 section 3.3
  125. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
  126. m = self.model[-1] # Detect() module
  127. for mi, s in zip(m.m, m.stride): # from
  128. b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
  129. b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
  130. b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
  131. mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  132. def _print_biases(self):
  133. m = self.model[-1] # Detect() module
  134. for mi in m.m: # from
  135. b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
  136. print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  137. # def _print_weights(self):
  138. # for m in self.model.modules():
  139. # if type(m) is Bottleneck:
  140. # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
  141. def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
  142. print('Fusing layers... ')
  143. for m in self.model.modules():
  144. if type(m) is Conv and hasattr(m, 'bn'):
  145. m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
  146. delattr(m, 'bn') # remove batchnorm
  147. m.forward = m.fuseforward # update forward
  148. self.info()
  149. return self
  150. def nms(self, mode=True): # add or remove NMS module
  151. present = type(self.model[-1]) is NMS # last layer is NMS
  152. if mode and not present:
  153. print('Adding NMS... ')
  154. m = NMS() # module
  155. m.f = -1 # from
  156. m.i = self.model[-1].i + 1 # index
  157. self.model.add_module(name='%s' % m.i, module=m) # add
  158. self.eval()
  159. elif not mode and present:
  160. print('Removing NMS... ')
  161. self.model = self.model[:-1] # remove
  162. return self
  163. def autoshape(self): # add autoShape module
  164. print('Adding autoShape... ')
  165. m = autoShape(self) # wrap model
  166. copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
  167. return m
  168. def info(self, verbose=False, img_size=640): # print model information
  169. model_info(self, verbose, img_size)
  170. def parse_model(d, ch): # model_dict, input_channels(3)
  171. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  172. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  173. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  174. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  175. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  176. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  177. m = eval(m) if isinstance(m, str) else m # eval strings
  178. for j, a in enumerate(args):
  179. try:
  180. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  181. except:
  182. pass
  183. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  184. if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
  185. C3, C3TR]:
  186. c1, c2 = ch[f], args[0]
  187. if c2 != no: # if not output
  188. c2 = make_divisible(c2 * gw, 8)
  189. args = [c1, c2, *args[1:]]
  190. if m in [BottleneckCSP, C3, C3TR]:
  191. args.insert(2, n) # number of repeats
  192. n = 1
  193. elif m is nn.BatchNorm2d:
  194. args = [ch[f]]
  195. elif m is Concat:
  196. c2 = sum([ch[x] for x in f])
  197. elif m is Detect:
  198. args.append([ch[x] for x in f])
  199. if isinstance(args[1], int): # number of anchors
  200. args[1] = [list(range(args[1] * 2))] * len(f)
  201. elif m is Contract:
  202. c2 = ch[f] * args[0] ** 2
  203. elif m is Expand:
  204. c2 = ch[f] // args[0] ** 2
  205. else:
  206. c2 = ch[f]
  207. m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  208. t = str(m)[8:-2].replace('__main__.', '') # module type
  209. np = sum([x.numel() for x in m_.parameters()]) # number params
  210. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  211. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  212. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  213. layers.append(m_)
  214. if i == 0:
  215. ch = []
  216. ch.append(c2)
  217. return nn.Sequential(*layers), sorted(save)
  218. if __name__ == '__main__':
  219. parser = argparse.ArgumentParser()
  220. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
  221. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  222. opt = parser.parse_args()
  223. opt.cfg = check_file(opt.cfg) # check file
  224. set_logging()
  225. device = select_device(opt.device)
  226. # Create model
  227. model = Model(opt.cfg).to(device)
  228. model.train()
  229. # Profile
  230. # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
  231. # y = model(img, profile=True)
  232. # Tensorboard
  233. # from torch.utils.tensorboard import SummaryWriter
  234. # tb_writer = SummaryWriter()
  235. # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
  236. # tb_writer.add_graph(model.model, img) # add model to tensorboard
  237. # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard