experimental.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from models.common import Conv
  10. from utils.downloads import attempt_download
  11. class CrossConv(nn.Module):
  12. # Cross Convolution Downsample
  13. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
  14. # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
  15. super().__init__()
  16. c_ = int(c2 * e) # hidden channels
  17. self.cv1 = Conv(c1, c_, (1, k), (1, s))
  18. self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
  19. self.add = shortcut and c1 == c2
  20. def forward(self, x):
  21. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  22. class Sum(nn.Module):
  23. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  24. def __init__(self, n, weight=False): # n: number of inputs
  25. super().__init__()
  26. self.weight = weight # apply weights boolean
  27. self.iter = range(n - 1) # iter object
  28. if weight:
  29. self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
  30. def forward(self, x):
  31. y = x[0] # no weight
  32. if self.weight:
  33. w = torch.sigmoid(self.w) * 2
  34. for i in self.iter:
  35. y = y + x[i + 1] * w[i]
  36. else:
  37. for i in self.iter:
  38. y = y + x[i + 1]
  39. return y
  40. class MixConv2d(nn.Module):
  41. # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
  42. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
  43. super().__init__()
  44. n = len(k) # number of convolutions
  45. if equal_ch: # equal c_ per group
  46. i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
  47. c_ = [(i == g).sum() for g in range(n)] # intermediate channels
  48. else: # equal weight.numel() per group
  49. b = [c2] + [0] * n
  50. a = np.eye(n + 1, n, k=-1)
  51. a -= np.roll(a, 1, axis=1)
  52. a *= np.array(k) ** 2
  53. a[0] = 1
  54. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  55. self.m = nn.ModuleList(
  56. [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
  57. self.bn = nn.BatchNorm2d(c2)
  58. self.act = nn.SiLU()
  59. def forward(self, x):
  60. return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  61. class Ensemble(nn.ModuleList):
  62. # Ensemble of models
  63. def __init__(self):
  64. super().__init__()
  65. def forward(self, x, augment=False, profile=False, visualize=False):
  66. y = []
  67. for module in self:
  68. y.append(module(x, augment, profile, visualize)[0])
  69. # y = torch.stack(y).max(0)[0] # max ensemble
  70. # y = torch.stack(y).mean(0) # mean ensemble
  71. y = torch.cat(y, 1) # nms ensemble
  72. return y, None # inference, train output
  73. def attempt_load(weights, map_location=None, inplace=True, fuse=True):
  74. from models.yolo import Detect, Model
  75. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  76. model = Ensemble()
  77. for w in weights if isinstance(weights, list) else [weights]:
  78. ckpt = torch.load(attempt_download(w), map_location=map_location) # load
  79. if fuse:
  80. # model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
  81. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # no fuse for bn prune
  82. else:
  83. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
  84. # Compatibility updates
  85. for m in model.modules():
  86. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
  87. m.inplace = inplace # pytorch 1.7.0 compatibility
  88. if type(m) is Detect:
  89. if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
  90. delattr(m, 'anchor_grid')
  91. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  92. elif type(m) is Conv:
  93. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  94. if len(model) == 1:
  95. return model[-1] # return model
  96. else:
  97. print(f'Ensemble created with {weights}\n')
  98. for k in ['names']:
  99. setattr(model, k, getattr(model[-1], k))
  100. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  101. return model # return ensemble
  102. def attempt_load_pruned(weights, map_location=None, inplace=True, fuse=True):
  103. from models.yolo import Detect, ModelPruned
  104. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  105. model = Ensemble()
  106. for w in weights if isinstance(weights, list) else [weights]:
  107. ckpt = torch.load(attempt_download(w), map_location=map_location) # load
  108. if fuse:
  109. # model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
  110. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # no fuse for bn prune
  111. else:
  112. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
  113. # Compatibility updates
  114. for m in model.modules():
  115. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, ModelPruned]:
  116. m.inplace = inplace # pytorch 1.7.0 compatibility
  117. if type(m) is Detect:
  118. if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
  119. delattr(m, 'anchor_grid')
  120. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  121. elif type(m) is Conv:
  122. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  123. if len(model) == 1:
  124. return model[-1] # return model
  125. else:
  126. print(f'Ensemble created with {weights}\n')
  127. for k in ['names']:
  128. setattr(model, k, getattr(model[-1], k))
  129. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  130. return model # return ensemble