# layer.py — building blocks (CBS / ELAN / MP / SPPCSPC / heads) for the classification model.
  1. import torch
  2. import sys
  3. sys.path.append('/home/yhsun/classification-main/')
  4. class cbs(torch.nn.Module):
  5. def __init__(self, in_, out_, kernel_size, stride):
  6. super().__init__()
  7. self.conv = torch.nn.Conv2d(in_, out_, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2,
  8. bias=False)
  9. self.bn = torch.nn.BatchNorm2d(out_, eps=0.001, momentum=0.03)
  10. self.silu = torch.nn.SiLU(inplace=True)
  11. def forward(self, x):
  12. x = self.conv(x)
  13. x = self.bn(x)
  14. x = self.silu(x)
  15. return x
  16. class concat(torch.nn.Module):
  17. def __init__(self, dim=1):
  18. super().__init__()
  19. self.concat = torch.cat
  20. self.dim = dim
  21. def forward(self, x):
  22. x = self.concat(x, dim=self.dim)
  23. return x
  24. class elan(torch.nn.Module): # in_->out_,len->len
  25. def __init__(self, in_, out_, n, config=None):
  26. super().__init__()
  27. if not config: # 正常版本
  28. self.cbs0 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
  29. self.cbs1 = cbs(in_, out_ // 4, kernel_size=1, stride=1)
  30. self.sequential2 = torch.nn.Sequential(
  31. *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
  32. self.sequential3 = torch.nn.Sequential(
  33. *(cbs(out_ // 4, out_ // 4, kernel_size=3, stride=1) for _ in range(n)))
  34. self.concat4 = concat()
  35. self.cbs5 = cbs(out_, out_, kernel_size=1, stride=1)
  36. else: # 剪枝版本。len(config) = 3 + 2 * n
  37. self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
  38. self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
  39. self.sequential2 = torch.nn.Sequential(
  40. *(cbs(config[1 + _], config[2 + _], kernel_size=3, stride=1) for _ in range(n)))
  41. self.sequential3 = torch.nn.Sequential(
  42. *(cbs(config[1 + n + _], config[2 + n + _], kernel_size=3, stride=1) for _ in range(n)))
  43. self.concat4 = concat()
  44. self.cbs5 = cbs(config[0] + config[1] + config[1 + n] + config[1 + 2 * n], config[2 + 2 * n],
  45. kernel_size=1, stride=1)
  46. def forward(self, x):
  47. x0 = self.cbs0(x)
  48. x1 = self.cbs1(x)
  49. x2 = self.sequential2(x1)
  50. x3 = self.sequential3(x2)
  51. x = self.concat4([x0, x1, x2, x3])
  52. x = self.cbs5(x)
  53. return x
  54. class mp(torch.nn.Module): # in_->out_,len->len//2
  55. def __init__(self, in_, out_, config=None):
  56. super().__init__()
  57. if not config: # 正常版本
  58. self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
  59. self.cbs1 = cbs(in_, out_ // 2, 1, 1)
  60. self.cbs2 = cbs(in_, out_ // 2, 1, 1)
  61. self.cbs3 = cbs(out_ // 2, out_ // 2, 3, 2)
  62. self.concat4 = concat(dim=1)
  63. else: # 剪枝版本。len(config) = 3
  64. self.maxpool0 = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1)
  65. self.cbs1 = cbs(in_, config[0], 1, 1)
  66. self.cbs2 = cbs(in_, config[1], 1, 1)
  67. self.cbs3 = cbs(config[1], config[2], 3, 2)
  68. self.concat4 = concat(dim=1)
  69. def forward(self, x):
  70. x0 = self.maxpool0(x)
  71. x0 = self.cbs1(x0)
  72. x1 = self.cbs2(x)
  73. x1 = self.cbs3(x1)
  74. x = self.concat4([x0, x1])
  75. return x
  76. class sppcspc(torch.nn.Module): # in_->out_,len->len
  77. def __init__(self, in_, out_, config=None):
  78. super().__init__()
  79. if not config: # 正常版本
  80. self.cbs0 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
  81. self.cbs1 = cbs(in_, in_ // 2, kernel_size=1, stride=1)
  82. self.cbs2 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
  83. self.cbs3 = cbs(in_ // 2, in_ // 2, kernel_size=1, stride=1)
  84. self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
  85. self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
  86. self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
  87. self.concat7 = concat(dim=1)
  88. self.cbs8 = cbs(2 * in_, in_ // 2, kernel_size=1, stride=1)
  89. self.cbs9 = cbs(in_ // 2, in_ // 2, kernel_size=3, stride=1)
  90. self.concat10 = concat(dim=1)
  91. self.cbs11 = cbs(in_, out_, kernel_size=1, stride=1)
  92. else: # 剪枝版本。len(config) = 7
  93. self.cbs0 = cbs(in_, config[0], kernel_size=1, stride=1)
  94. self.cbs1 = cbs(in_, config[1], kernel_size=1, stride=1)
  95. self.cbs2 = cbs(config[1], config[2], kernel_size=3, stride=1)
  96. self.cbs3 = cbs(config[2], config[3], kernel_size=1, stride=1)
  97. self.MaxPool2d4 = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1)
  98. self.MaxPool2d5 = torch.nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1)
  99. self.MaxPool2d6 = torch.nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1)
  100. self.concat7 = concat(dim=1)
  101. self.cbs8 = cbs(4 * config[3], config[4], kernel_size=1, stride=1)
  102. self.cbs9 = cbs(config[4], config[5], kernel_size=3, stride=1)
  103. self.concat10 = concat(dim=1)
  104. self.cbs11 = cbs(config[0] + config[5], config[6], kernel_size=1, stride=1)
  105. def forward(self, x):
  106. x0 = self.cbs0(x)
  107. x1 = self.cbs1(x)
  108. x1 = self.cbs2(x1)
  109. x1 = self.cbs3(x1)
  110. x4 = self.MaxPool2d4(x1)
  111. x5 = self.MaxPool2d5(x1)
  112. x6 = self.MaxPool2d6(x1)
  113. x = self.concat7([x1, x4, x5, x6])
  114. x = self.cbs8(x)
  115. x = self.cbs9(x)
  116. x = self.concat10([x, x0])
  117. x = self.cbs11(x)
  118. return x
  119. class linear_head(torch.nn.Module):
  120. def __init__(self, in_, out_):
  121. super().__init__()
  122. self.avgpool0 = torch.nn.AdaptiveAvgPool2d(1)
  123. self.flatten1 = torch.nn.Flatten()
  124. self.Dropout2 = torch.nn.Dropout(0.2)
  125. self.linear3 = torch.nn.Linear(in_, in_ // 2)
  126. self.silu4 = torch.nn.SiLU()
  127. self.Dropout5 = torch.nn.Dropout(0.2)
  128. self.linear6 = torch.nn.Linear(in_ // 2, out_)
  129. def forward(self, x):
  130. x = self.avgpool0(x)
  131. x = self.flatten1(x)
  132. x = self.Dropout2(x)
  133. x = self.linear3(x)
  134. x = self.silu4(x)
  135. x = self.Dropout5(x)
  136. x = self.linear6(x)
  137. return x
  138. class image_deal(torch.nn.Module): # 归一化
  139. def __init__(self):
  140. super().__init__()
  141. def forward(self, x):
  142. x = x / 255
  143. x = x.permute(0, 3, 1, 2)
  144. return x
  145. class deploy(torch.nn.Module):
  146. def __init__(self, model, normalization):
  147. super().__init__()
  148. self.image_deal = image_deal()
  149. self.model = model
  150. if normalization == 'softmax':
  151. self.normalization = torch.nn.Softmax(dim=1)
  152. else:
  153. self.normalization = torch.nn.Sigmoid()
  154. def forward(self, x):
  155. x = self.image_deal(x)
  156. x = self.model(x)
  157. x = self.normalization(x)
  158. return x