# layer.py

import torch


class cbs(torch.nn.Module):  # Conv2d -> BatchNorm2d -> SiLU
    def __init__(self, in_, out_, kernel_size, stride):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_, out_, kernel_size=kernel_size, stride=stride,
                                    padding=(kernel_size - 1) // 2, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_, eps=0.001, momentum=0.03)
        self.silu = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.silu(x)
        return x
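
# A minimal shape check for cbs (a sketch; the sizes below are illustrative assumptions):
#   >>> m = cbs(3, 32, kernel_size=3, stride=2)
#   >>> m(torch.zeros(1, 3, 640, 640)).shape
#   torch.Size([1, 32, 320, 320])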


class concat(torch.nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.concat = torch.concat
        self.dim = dim

    def forward(self, x):
        x = self.concat(x, dim=self.dim)
        return x


class residual(torch.nn.Module):  # in_->in_, len->len
    def __init__(self, in_, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, in_, kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, in_, kernel_size=3, stride=1)
        else:  # pruned version. len(config) = 2
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.cbs1 = cbs(config[0], config[1], kernel_size=3, stride=1)

    def forward(self, x):
        x0 = self.cbs0(x)
        x0 = self.cbs1(x0)
        return x + x0  # shortcut add: output keeps the input channel count


class c3(torch.nn.Module):  # in_->out_, len->len
    def __init__(self, in_, out_, n, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.sequential1 = torch.nn.Sequential(*(residual(in_ // 2) for _ in range(n)))
            self.cbs2 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.concat3 = concat(dim=1)
            self.cbs4 = cbs(in_, out_, kernel_size=1, stride=1)
        else:  # pruned version. len(config) = 3 + 2 * n
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.sequential1 = torch.nn.Sequential(
                *(residual(config[0 + 2 * _] if _ == 0 else config[1 + 2 * _] + config[2 + 2 * _],
                           config[1 + 2 * _:3 + 2 * _]) for _ in range(n)))
            # cbs2 consumes the block input x, so its in-channels must be in_, mirroring the
            # unpruned branch (the original code passed config[0], which does not match x
            # and would fail at runtime)
            self.cbs2 = cbs(in_, config[1 + 2 * n], kernel_size=1, stride=1)
            self.concat3 = concat(dim=1)
            self.cbs4 = cbs(config[0] + config[2 * n - 1] + config[2 * n] + config[1 + 2 * n], config[2 + 2 * n],
                            kernel_size=1, stride=1)

    def forward(self, x):
        x0 = self.cbs0(x)
        x1 = self.sequential1(x0)
        x1 = x0 + x1  # extra shortcut around the whole residual chain
        x2 = self.cbs2(x)
        x = self.concat3([x1, x2])
        x = self.cbs4(x)
        return x
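
# Channel flow of the unpruned c3 (a sketch; the sizes are illustrative assumptions):
#   >>> m = c3(64, 128, n=2)
#   >>> m(torch.zeros(1, 64, 80, 80)).shape
#   torch.Size([1, 128, 80, 80])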


class elan(torch.nn.Module):  # in_->out_, len->len
    def __init__(self, in_, out_, n, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
            self.sequential2 = torch.nn.Sequential(
                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
            self.sequential3 = torch.nn.Sequential(
                *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
            self.concat4 = concat()
            self.cbs5 = cbs(out_, out_, kernel_size=1, stride=1)
        else:  # pruned version. len(config) = 3 + 2 * n
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
            self.sequential2 = torch.nn.Sequential(
                *(cbs(config[1 + _], config[2 + _], kernel_size=3, stride=1) for _ in range(n)))
            self.sequential3 = torch.nn.Sequential(
                *(cbs(config[1 + n + _], config[2 + n + _], kernel_size=3, stride=1) for _ in range(n)))
            self.concat4 = concat()
            self.cbs5 = cbs(config[0] + config[1] + config[1 + n] + config[1 + 2 * n], config[2 + 2 * n],
                            kernel_size=1, stride=1)

    def forward(self, x):
        x0 = self.cbs0(x)
        x1 = self.cbs1(x)
        x2 = self.sequential2(x1)
        x3 = self.sequential3(x2)
        x = self.concat4([x0, x1, x2, x3])
        x = self.cbs5(x)
        return x
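
# elan taps both 1x1 branches plus the end of each 3x3 chain before fusing
# (a sketch; the sizes are illustrative assumptions):
#   >>> m = elan(128, 256, n=2)
#   >>> m(torch.zeros(1, 128, 40, 40)).shape
#   torch.Size([1, 256, 40, 40])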


class elan_h(torch.nn.Module):  # in_->out_, len->len
    def __init__(self, in_, out_, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.cbs2 = cbs(in_ // 2, in_ // 4, kernel_size=3, stride=1)
            self.cbs3 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
            self.cbs4 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
            self.cbs5 = cbs(in_ // 4, in_ // 4, kernel_size=3, stride=1)
            self.concat6 = concat()
            self.cbs7 = cbs(2 * in_, out_, kernel_size=1, stride=1)
        else:  # pruned version. len(config) = 7
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
            self.cbs3 = cbs(config[2], config[3], kernel_size=3, stride=1)
            self.cbs4 = cbs(config[3], config[4], kernel_size=3, stride=1)
            self.cbs5 = cbs(config[4], config[5], kernel_size=3, stride=1)
            self.concat6 = concat()
            self.cbs7 = cbs(config[0] + config[1] + config[2] + config[3] + config[4] + config[5], config[6],
                            kernel_size=1, stride=1)

    def forward(self, x):
        x0 = self.cbs0(x)
        x1 = self.cbs1(x)
        x2 = self.cbs2(x1)
        x3 = self.cbs3(x2)
        x4 = self.cbs4(x3)
        x5 = self.cbs5(x4)
        x = self.concat6([x0, x1, x2, x3, x4, x5])
        x = self.cbs7(x)
        return x


class mp(torch.nn.Module):  # in_->out_, len->len//2
    def __init__(self, in_, out_, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
            self.cbs1 = cbs(in_, out_ // 2, kernel_size=1, stride=1)
            self.cbs2 = cbs(in_, out_ // 2, kernel_size=1, stride=1)
            self.cbs3 = cbs(out_ // 2, out_ // 2, kernel_size=3, stride=2)
            self.concat4 = concat(dim=1)
        else:  # pruned version. len(config) = 3
            self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
            self.cbs1 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.cbs2 = cbs(in_, config[1], kernel_size=1, stride=1)
            self.cbs3 = cbs(config[1], config[2], kernel_size=3, stride=2)
            self.concat4 = concat(dim=1)

    def forward(self, x):
        x0 = self.maxpool0(x)  # downsample path 1: max-pool then 1x1 conv
        x0 = self.cbs1(x0)
        x1 = self.cbs2(x)      # downsample path 2: 1x1 conv then strided 3x3 conv
        x1 = self.cbs3(x1)
        x = self.concat4([x0, x1])
        return x
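
# mp halves the spatial size along both paths, then fuses them (a sketch; sizes assumed):
#   >>> m = mp(128, 256)
#   >>> m(torch.zeros(1, 128, 40, 40)).shape
#   torch.Size([1, 256, 20, 20])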


class sppf(torch.nn.Module):  # in_->out_, len->len
    def __init__(self, in_, out_, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.MaxPool2d1 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
            self.MaxPool2d2 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
            self.MaxPool2d3 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
            self.concat4 = concat(dim=1)
            self.cbs5 = cbs(2 * in_, out_, kernel_size=1, stride=1)
        else:  # pruned version. len(config) = 2
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.MaxPool2d1 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
            self.MaxPool2d2 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
            self.MaxPool2d3 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
            self.concat4 = concat(dim=1)
            self.cbs5 = cbs(4 * config[0], config[1], kernel_size=1, stride=1)

    def forward(self, x):
        x = self.cbs0(x)
        x0 = self.MaxPool2d1(x)
        x1 = self.MaxPool2d2(x0)  # pools chain: each acts on the previous pooled output
        x2 = self.MaxPool2d3(x1)
        x = self.concat4([x, x0, x1, x2])
        x = self.cbs5(x)
        return x
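
# The chained same-size pools grow the receptive field at each tap while keeping the
# spatial size (a sketch; the sizes are illustrative assumptions):
#   >>> m = sppf(256, 256)
#   >>> m(torch.zeros(1, 256, 20, 20)).shape
#   torch.Size([1, 256, 20, 20])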


class sppcspc(torch.nn.Module):  # in_->out_, len->len
    def __init__(self, in_, out_, config=None):
        super().__init__()
        if not config:  # unpruned version
            self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
            self.cbs2 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
            self.cbs3 = cbs(in_ // 2, in_ // 2, kernel_size=1, stride=1)
            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
            self.concat7 = concat(dim=1)
            self.cbs8 = cbs(2 * in_, in_ // 2, kernel_size=1, stride=1)
            self.cbs9 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
            self.concat10 = concat(dim=1)
            self.cbs11 = cbs(in_, out_, kernel_size=1, stride=1)
        else:  # pruned version. len(config) = 7
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
            self.cbs3 = cbs(config[2], config[3], kernel_size=1, stride=1)
            self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
            self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
            self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
            self.concat7 = concat(dim=1)
            self.cbs8 = cbs(4 * config[3], config[4], kernel_size=1, stride=1)
            self.cbs9 = cbs(config[4], config[5], kernel_size=3, stride=1)
            self.concat10 = concat(dim=1)
            self.cbs11 = cbs(config[0] + config[5], config[6], kernel_size=1, stride=1)

    def forward(self, x):
        x0 = self.cbs0(x)  # cross-stage shortcut branch
        x1 = self.cbs1(x)
        x1 = self.cbs2(x1)
        x1 = self.cbs3(x1)
        x4 = self.MaxPool2d4(x1)  # unlike sppf, the three pools act in parallel on x1
        x5 = self.MaxPool2d5(x1)
        x6 = self.MaxPool2d6(x1)
        x = self.concat7([x1, x4, x5, x6])
        x = self.cbs8(x)
        x = self.cbs9(x)
        x = self.concat10([x, x0])
        x = self.cbs11(x)
        return x


class head(torch.nn.Module):  # in_->(batch, 3, output_size, output_size, 5+output_class), len->len
    def __init__(self, in_, output_size, output_class):
        super().__init__()
        self.output_size = output_size
        self.output_class = output_class
        self.output = torch.nn.Conv2d(in_, 3 * (5 + output_class), kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.output(x).reshape(-1, 3, self.output_size, self.output_size, 5 + self.output_class)  # reshape
        return x
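
# head packs 3 anchors x (5 + classes) predictions per cell (a sketch; the sizes are
# illustrative assumptions, e.g. 80 classes as in COCO):
#   >>> m = head(256, output_size=20, output_class=80)
#   >>> m(torch.zeros(1, 256, 20, 20)).shape
#   torch.Size([1, 3, 20, 20, 85])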


# Based on YOLOX
class split_head(torch.nn.Module):  # in_->(batch, 3, output_size, output_size, 5+output_class), len->len
    def __init__(self, in_, output_size, output_class, config=None):
        super().__init__()
        self.output_size = output_size
        self.output_class = output_class
        if not config:  # unpruned version
            out_ = 3 * (5 + self.output_class)
            self.cbs0 = cbs(in_, out_, kernel_size=1, stride=1)
            self.cbs1 = cbs(out_, out_, kernel_size=3, stride=1)
            self.cbs2 = cbs(out_, out_, kernel_size=3, stride=1)
            self.cbs3 = cbs(out_, out_, kernel_size=3, stride=1)
            self.cbs4 = cbs(out_, out_, kernel_size=3, stride=1)
            self.Conv2d5 = torch.nn.Conv2d(out_, 12, kernel_size=1, stride=1, padding=0)
            self.Conv2d6 = torch.nn.Conv2d(out_, 3, kernel_size=1, stride=1, padding=0)
            self.Conv2d7 = torch.nn.Conv2d(out_, 3 * self.output_class, kernel_size=1, stride=1, padding=0)
            self.concat8 = concat(4)
        else:  # pruned version. len(config) = 8
            self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
            # kernel_size=3 mirrors the unpruned branch (the original pruned branch used
            # kernel_size=1 here, which changes the architecture rather than just the widths)
            self.cbs1 = cbs(config[0], config[1], kernel_size=3, stride=1)
            self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
            self.cbs3 = cbs(config[0], config[3], kernel_size=3, stride=1)
            self.cbs4 = cbs(config[3], config[4], kernel_size=3, stride=1)
            self.Conv2d5 = torch.nn.Conv2d(config[5], 12, kernel_size=1, stride=1, padding=0)
            self.Conv2d6 = torch.nn.Conv2d(config[6], 3, kernel_size=1, stride=1, padding=0)
            self.Conv2d7 = torch.nn.Conv2d(config[7], 3 * self.output_class, kernel_size=1, stride=1, padding=0)
            self.concat8 = concat(4)

    def forward(self, x):
        x = self.cbs0(x)
        x0 = self.cbs1(x)
        x0 = self.cbs2(x0)
        x1 = self.cbs3(x)
        x1 = self.cbs4(x1)
        x2 = self.Conv2d5(x0).reshape(-1, 3, self.output_size, self.output_size, 4)  # box branch
        x3 = self.Conv2d6(x0).reshape(-1, 3, self.output_size, self.output_size, 1)  # objectness branch
        x4 = self.Conv2d7(x1).reshape(-1, 3, self.output_size, self.output_size, self.output_class)  # class branch
        x = self.concat8([x2, x3, x4])
        return x
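
# split_head's three branches reassemble into the same layout as head (a sketch; sizes assumed):
#   >>> m = split_head(256, output_size=20, output_class=80)
#   >>> m(torch.zeros(1, 256, 20, 20)).shape
#   torch.Size([1, 3, 20, 20, 85])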


class image_deal(torch.nn.Module):  # normalize to [0,1] and convert NHWC to NCHW
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x / 255
        x = x.permute(0, 3, 1, 2)
        return x


class decode(torch.nn.Module):  # raw (Cx,Cy,w,h,confidence,...) outputs -> (Cx,Cy,w,h,confidence,...) in image coordinates
    def __init__(self, input_size):
        super().__init__()
        self.stride = (8, 16, 32)
        output_size = [int(input_size // i) for i in self.stride]
        self.anchor = (((12, 16), (19, 36), (40, 28)), ((36, 75), (76, 55), (72, 146)),
                       ((142, 110), (192, 243), (459, 401)))
        self.grid = [0, 0, 0]
        for i in range(3):
            self.grid[i] = torch.arange(output_size[i])
        self.frame_sigmoid = torch.nn.Sigmoid()

    def forward(self, output):
        device = output[0].device
        # iterate over the three detection levels
        for i in range(3):
            self.grid[i] = self.grid[i].to(device)  # move the grid to the matching device
            output[i] = self.frame_sigmoid(output[i])  # normalize all raw outputs to [0,1]
            # center coordinates: sigmoid [0,1] -> [-0.5,1.5] per cell, shifted by the grid
            # index and scaled by the stride into image coordinates
            output[i][..., 0] = (2 * output[i][..., 0] - 0.5 + self.grid[i].unsqueeze(1)) * self.stride[i]
            output[i][..., 1] = (2 * output[i][..., 1] - 0.5 + self.grid[i]) * self.stride[i]
            # iterate over the three anchors within this level
            for j in range(3):
                output[i][:, j, ..., 2] = 4 * output[i][:, j, ..., 2] ** 2 * self.anchor[i][j][0]  # [0,1] -> [0, 4*anchor]
                output[i][:, j, ..., 3] = 4 * output[i][:, j, ..., 3] ** 2 * self.anchor[i][j][1]  # [0,1] -> [0, 4*anchor]
        return output
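
# decode expects a list of three head outputs, one per stride (a sketch; the batch size,
# grid sizes, and 80-class layout are illustrative assumptions):
#   >>> d = decode(input_size=640)
#   >>> outs = d([torch.zeros(1, 3, s, s, 85) for s in (80, 40, 20)])
#   >>> outs[0].shape
#   torch.Size([1, 3, 80, 80, 85])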


class deploy(torch.nn.Module):  # preprocessing + model + decoding in one module
    def __init__(self, model, input_size):
        super().__init__()
        self.image_deal = image_deal()
        self.model = model
        self.decode = decode(input_size)

    def forward(self, x):
        x = self.image_deal(x)
        x = self.model(x)
        x = self.decode(x)
        return x
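

# End-to-end smoke test for deploy. A minimal sketch: dummy_model below is a hypothetical
# stand-in for the real detector this file is paired with, and all sizes are assumptions
# for illustration.
if __name__ == '__main__':
    class dummy_model(torch.nn.Module):  # emits three head-shaped outputs, one per stride
        def forward(self, x):
            return [torch.zeros(x.shape[0], 3, s, s, 85) for s in (80, 40, 20)]

    net = deploy(dummy_model(), input_size=640)
    image = torch.zeros(2, 640, 640, 3)  # NHWC image batch, as image_deal expects
    outputs = net(image)
    print([tuple(_.shape) for _ in outputs])  # [(2, 3, 80, 80, 85), (2, 3, 40, 40, 85), (2, 3, 20, 20, 85)]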