yolov7.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
# Adapted from YOLOv7: https://github.com/WongKinYiu/yolov7
  2. import torch
  3. from model.layer import cbs, elan, elan_h, mp, sppcspc, concat, head
class yolov7(torch.nn.Module):
    """YOLOv7 detection network, adapted from https://github.com/WongKinYiu/yolov7.

    Builds one of two variants:
      * standard: channel widths scaled by a base ``dim`` chosen from ``args.model_type``;
      * pruned:   per-layer channel counts read sequentially from ``args.prune_num``.

    Expected attributes on ``args``:
        model_type (str): one of 'n' / 's' / 'm' / 'l'; selects base width and ELAN depth.
        input_size (int): square input resolution; must be divisible by 32 (max stride).
        output_class (int): number of object classes passed to each detection head.
        prune (bool): when True, build the pruned variant from ``prune_num``.
        prune_num (sequence of int): flat per-layer channel configuration
            (pruned variant only).
    """

    def __init__(self, args):
        super().__init__()
        dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}  # base channel width per model size
        n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}  # ELAN repeat count per model size
        dim = dim_dict[args.model_type]
        n = n_dict[args.model_type]
        input_size = args.input_size
        stride = (8, 16, 32)  # downsampling factor of each detection head
        self.output_size = [int(input_size // i) for i in stride]  # spatial size of each output layer, e.g. (80, 40, 20)
        self.output_class = args.output_class
        # network structure
        if not args.prune:  # standard (un-pruned) version
            self.l0 = cbs(3, dim, 3, 1)
            self.l1 = cbs(dim, 2 * dim, 3, 2)  # input_size/2
            self.l2 = cbs(2 * dim, 2 * dim, 3, 1)
            self.l3 = cbs(2 * dim, 4 * dim, 3, 2)  # input_size/4
            # ---------- backbone ---------- #
            self.l4 = elan(4 * dim, 8 * dim, n)
            self.l5 = mp(8 * dim, 8 * dim)  # input_size/8
            self.l6 = elan(8 * dim, 16 * dim, n)
            self.l7 = mp(16 * dim, 16 * dim)  # input_size/16
            self.l8 = elan(16 * dim, 32 * dim, n)
            self.l9 = mp(32 * dim, 32 * dim)  # input_size/32
            self.l10 = elan(32 * dim, 32 * dim, n)
            self.l11 = sppcspc(32 * dim, 16 * dim)
            self.l12 = cbs(16 * dim, 8 * dim, 1, 1)
            # ---------- top-down (FPN) path ---------- #
            self.l13 = torch.nn.Upsample(scale_factor=2)  # input_size/16
            self.l8_add = cbs(32 * dim, 8 * dim, 1, 1)  # lateral connection from l8
            self.l14 = concat(1)
            self.l15 = elan_h(16 * dim, 8 * dim)
            self.l16 = cbs(8 * dim, 4 * dim, 1, 1)
            # ---------- #
            self.l17 = torch.nn.Upsample(scale_factor=2)  # input_size/8
            self.l6_add = cbs(16 * dim, 4 * dim, 1, 1)  # lateral connection from l6
            self.l18 = concat(1)
            self.l19 = elan_h(8 * dim, 4 * dim)  # feeds output0
            # ---------- bottom-up (PAN) path ---------- #
            self.l20 = mp(4 * dim, 8 * dim)
            self.l21 = concat(1)
            self.l22 = elan_h(16 * dim, 8 * dim)  # feeds output1
            # ---------- #
            self.l23 = mp(8 * dim, 16 * dim)
            self.l24 = concat(1)
            self.l25 = elan_h(32 * dim, 16 * dim)  # feeds output2
            # ---------- detection heads ---------- #
            self.output0 = head(4 * dim, self.output_size[0], self.output_class)
            self.output1 = head(8 * dim, self.output_size[1], self.output_class)
            self.output2 = head(16 * dim, self.output_size[2], self.output_class)
        else:  # pruned version
            # Channel counts are consumed sequentially from `config`; the growing
            # offsets (e.g. 7 + 2*n, 16 + 4*n, 25 + 8*n) track how many entries each
            # preceding elan/mp/sppcspc block used, so the slice bounds must stay
            # exactly in step with the layer order above.
            config = args.prune_num
            # NOTE(review): kernel size is 1 here for l0/l2, but 3 in the standard
            # branch — confirm this difference is intentional.
            self.l0 = cbs(3, config[0], 1, 1)
            self.l1 = cbs(config[0], config[1], 3, 2)  # input_size/2
            self.l2 = cbs(config[1], config[2], 1, 1)
            self.l3 = cbs(config[2], config[3], 3, 2)  # input_size/4
            # ---------- backbone ---------- #
            self.l4 = elan(config[3], None, n, config[4:7 + 2 * n])
            self.l5 = mp(config[6 + 2 * n], None, config[7 + 2 * n:10 + 2 * n])  # input_size/8
            self.l6 = elan(config[7 + 2 * n] + config[9 + 2 * n], None, n, config[10 + 2 * n:13 + 4 * n])
            self.l7 = mp(config[12 + 4 * n], None, config[13 + 4 * n:16 + 4 * n])  # input_size/16
            self.l8 = elan(config[13 + 4 * n] + config[15 + 4 * n], None, n, config[16 + 4 * n:19 + 6 * n])
            self.l9 = mp(config[18 + 6 * n], None, config[19 + 6 * n:22 + 6 * n])  # input_size/32
            self.l10 = elan(config[19 + 6 * n] + config[21 + 6 * n], None, n, config[22 + 6 * n:25 + 8 * n])
            self.l11 = sppcspc(config[24 + 8 * n], None, config[25 + 8 * n:32 + 8 * n])
            self.l12 = cbs(config[31 + 8 * n], config[32 + 8 * n], 1, 1)
            # ---------- top-down (FPN) path ---------- #
            self.l13 = torch.nn.Upsample(scale_factor=2)  # input_size/16
            self.l8_add = cbs(config[18 + 6 * n], config[33 + 8 * n], 1, 1)  # lateral connection from l8
            self.l14 = concat(1)
            self.l15 = elan_h(config[32 + 8 * n] + config[33 + 8 * n], None, config[34 + 8 * n:41 + 8 * n])
            self.l16 = cbs(config[40 + 8 * n], config[41 + 8 * n], 1, 1)
            # ---------- #
            self.l17 = torch.nn.Upsample(scale_factor=2)  # input_size/8
            self.l6_add = cbs(config[12 + 4 * n], config[42 + 8 * n], 1, 1)  # lateral connection from l6
            self.l18 = concat(1)
            self.l19 = elan_h(config[41 + 8 * n] + config[42 + 8 * n], None, config[43 + 8 * n:50 + 8 * n])  # feeds output0
            # ---------- bottom-up (PAN) path ---------- #
            self.l20 = mp(config[49 + 8 * n], None, config[50 + 8 * n:53 + 8 * n])
            self.l21 = concat(1)
            self.l22 = elan_h(config[40 + 8 * n] + config[50 + 8 * n] + config[52 + 8 * n], None,
                              config[53 + 8 * n:60 + 8 * n])  # feeds output1
            # ---------- #
            self.l23 = mp(config[59 + 8 * n], None, config[60 + 8 * n:63 + 8 * n])
            self.l24 = concat(1)
            self.l25 = elan_h(config[31 + 8 * n] + config[60 + 8 * n] + config[62 + 8 * n], None,
                              config[63 + 8 * n:70 + 8 * n])  # feeds output2
            # ---------- detection heads ---------- #
            self.output0 = head(config[49 + 8 * n], self.output_size[0], self.output_class)
            self.output1 = head(config[59 + 8 * n], self.output_size[1], self.output_class)
            self.output2 = head(config[69 + 8 * n], self.output_size[2], self.output_class)

    def forward(self, x):
        """Run the network on a batch of images.

        Returns a list ``[output0, output1, output2]`` — head outputs for
        strides 8, 16 and 32 respectively.
        """
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        l6 = self.l6(x)  # kept for the later lateral connection (l6_add)
        x = self.l7(l6)
        l8 = self.l8(x)  # kept for the later lateral connection (l8_add)
        x = self.l9(l8)
        x = self.l10(x)
        l11 = self.l11(x)  # kept for the stride-32 concat (l24)
        x = self.l12(l11)
        x = self.l13(x)
        l8_add = self.l8_add(l8)
        x = self.l14([x, l8_add])
        l15 = self.l15(x)  # kept for the stride-16 concat (l21)
        x = self.l16(l15)
        x = self.l17(x)
        l6_add = self.l6_add(l6)
        x = self.l18([x, l6_add])
        x = self.l19(x)
        output0 = self.output0(x)
        x = self.l20(x)
        x = self.l21([x, l15])
        x = self.l22(x)
        output1 = self.output1(x)
        x = self.l23(x)
        x = self.l24([x, l11])
        x = self.l25(x)
        output2 = self.output2(x)
        return [output0, output1, output2]
  128. if __name__ == '__main__':
  129. import argparse
  130. parser = argparse.ArgumentParser(description='')
  131. parser.add_argument('--prune', default=False, type=bool)
  132. parser.add_argument('--model_type', default='n', type=str)
  133. parser.add_argument('--input_size', default=640, type=int)
  134. parser.add_argument('--output_class', default=1, type=int)
  135. args = parser.parse_args()
  136. model = yolov7(args)
  137. tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
  138. pred = model(tensor)
  139. print(model)
  140. print(pred[0].shape, pred[1].shape, pred[2].shape)