yolov5.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # 根据yolov5改编:https://github.com/ultralytics/yolov5
  2. import torch
  3. from torch import nn
  4. from model.layer import cbs, c3, sppf, concat, head
  5. class yolov5(torch.nn.Module):
  6. def __init__(self, args):
  7. super().__init__()
  8. dim_dict = {'n': 8, 's': 16, 'm': 32, 'l': 64}
  9. n_dict = {'n': 1, 's': 1, 'm': 2, 'l': 3}
  10. dim = dim_dict[args.model_type]
  11. n = n_dict[args.model_type]
  12. input_size = args.input_size
  13. stride = (8, 16, 32)
  14. self.output_size = [int(input_size // i) for i in stride] # 每个输出层的尺寸,如(80,40,20)
  15. self.output_class = args.output_class
  16. # 网络结构
  17. self.l0 = cbs(3, dim, 6, 2) # 1/2
  18. self.l1 = cbs(dim, 2 * dim, 3, 2) # 1/4
  19. # ---------- #
  20. self.l2 = c3(2 * dim, 2 * dim, n)
  21. self.l3 = cbs(2 * dim, 4 * dim, 3, 2) # 1/8
  22. self.l4 = c3(4 * dim, 4 * dim, 2 * n)
  23. self.l5 = cbs(4 * dim, 8 * dim, 3, 2) # 1/16
  24. self.l6 = c3(8 * dim, 8 * dim, 3 * n)
  25. self.l7 = cbs(8 * dim, 16 * dim, 3, 2) # 1/32
  26. self.l8 = c3(16 * dim, 16 * dim, n)
  27. self.l9 = sppf(16 * dim, 16 * dim)
  28. self.l10 = cbs(16 * dim, 8 * dim, 1, 1)
  29. # ---------- #
  30. self.l11 = torch.nn.Upsample(scale_factor=2) # 1/16
  31. self.l12 = concat(1)
  32. self.l13 = c3(16 * dim, 8 * dim, n)
  33. self.l14 = cbs(8 * dim, 4 * dim, 1, 1)
  34. # ---------- #
  35. self.l15 = torch.nn.Upsample(scale_factor=2) # 1/8
  36. self.l16 = concat(1)
  37. self.l17 = c3(8 * dim, 4 * dim, n) # 接output0
  38. # ---------- #
  39. self.l18 = cbs(4 * dim, 4 * dim, 3, 2) # 1/16
  40. self.l19 = concat(1)
  41. self.l20 = c3(8 * dim, 8 * dim, n) # 接output1
  42. # ---------- #
  43. self.l21 = cbs(8 * dim, 8 * dim, 3, 2) # 1/32
  44. self.l22 = concat(1)
  45. self.l23 = c3(16 * dim, 16 * dim, n) # 接output2
  46. # ---------- #
  47. self.output0 = head(4 * dim, self.output_size[0], self.output_class)
  48. self.output1 = head(8 * dim, self.output_size[1], self.output_class)
  49. self.output2 = head(16 * dim, self.output_size[2], self.output_class)
  50. def forward(self, x):
  51. x = self.l0(x)
  52. x = self.l1(x)
  53. x = self.l2(x)
  54. x = self.l3(x)
  55. l4 = self.l4(x)
  56. x = self.l5(l4)
  57. l6 = self.l6(x)
  58. x = self.l7(l6)
  59. x = self.l8(x)
  60. x = self.l9(x)
  61. l10 = self.l10(x)
  62. x = self.l11(l10)
  63. x = self.l12([x, l6])
  64. x = self.l13(x)
  65. l14 = self.l14(x)
  66. x = self.l15(l14)
  67. x = self.l16([x, l4])
  68. x = self.l17(x)
  69. output0 = self.output0(x)
  70. x = self.l18(x)
  71. x = self.l19([x, l14])
  72. x = self.l20(x)
  73. output1 = self.output1(x)
  74. x = self.l21(x)
  75. x = self.l22([x, l10])
  76. x = self.l23(x)
  77. output2 = self.output2(x)
  78. return [output0, output1, output2]
  79. def get_encode_layers(self):
  80. """
  81. 获取用于白盒模型水印加密层,每个模型根据复杂度选择合适的卷积层
  82. """
  83. conv_list = []
  84. for module in self.modules():
  85. if isinstance(module, nn.Conv2d) and module.out_channels > 100:
  86. conv_list.append(module)
  87. return conv_list[0:2]
  88. if __name__ == '__main__':
  89. import argparse
  90. parser = argparse.ArgumentParser(description='')
  91. parser.add_argument('--prune', default=False, type=bool)
  92. parser.add_argument('--model_type', default='n', type=str)
  93. parser.add_argument('--input_size', default=640, type=int)
  94. parser.add_argument('--output_class', default=1, type=int)
  95. args = parser.parse_args()
  96. model = yolov5(args)
  97. tensor = torch.rand(2, 3, args.input_size, args.input_size, dtype=torch.float32)
  98. pred = model(tensor)
  99. print(model)
  100. print(pred[0].shape, pred[1].shape, pred[2].shape)