pruned_common.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import json
  6. import math
  7. import platform
  8. import warnings
  9. from collections import OrderedDict, namedtuple
  10. from copy import copy
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import requests
  16. import torch
  17. import torch.nn as nn
  18. from PIL import Image
  19. from torch.cuda import amp
  20. from utils.datasets import exif_transpose, letterbox
  21. from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
  22. make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
  23. from utils.plots import Annotator, colors, save_one_box
  24. from utils.torch_utils import copy_attr, time_sync
  25. from models.common import Conv
  class BottleneckPruned(nn.Module):
      """Residual bottleneck built with channel widths taken from a pruned model.

      Two ``Conv`` layers are applied in series: ``Conv(cv1in, cv1out, 1, 1)``
      and ``Conv(cv1out, cv2out, 3, 1, g=g)``. The positional args are assumed
      to be (c_in, c_out, kernel, stride) as in YOLOv5's ``models.common.Conv``
      (TODO confirm).

      The block input is added back onto the result only when ``shortcut`` is
      requested AND the block keeps the channel count (``cv1in == cv2out``).

      Args:
          cv1in: channels entering the block (input of the first conv).
          cv1out: channels produced by the first conv (pruned hidden width).
          cv2out: channels produced by the second conv (block output width).
          shortcut: whether a residual connection is requested.
          g: group count forwarded to the second conv.
      """

      def __init__(self, cv1in, cv1out, cv2out, shortcut=True, g=1):
          super().__init__()
          self.cv1 = Conv(cv1in, cv1out, 1, 1)
          self.cv2 = Conv(cv1out, cv2out, 3, 1, g=g)
          # An element-wise residual sum is only valid when in/out widths agree.
          self.add = shortcut and cv1in == cv2out

      def forward(self, x):
          out = self.cv2(self.cv1(x))
          if not self.add:
              return out
          return x + out
  class C3Pruned(nn.Module):
      """CSP bottleneck with three convolutions, using pruned channel widths.

      The input feeds two parallel branches:

      * ``cv1`` followed by the chain ``m`` of ``BottleneckPruned`` blocks;
      * ``cv2`` alone.

      Their outputs are concatenated along dim 1 (bottleneck branch first) and
      fused by ``cv3``.

      Args:
          cv1in: input channels, shared by ``cv1`` and ``cv2``.
          cv1out: output channels of ``cv1`` (input of the first bottleneck).
          cv2out: output channels of the ``cv2`` branch.
          cv3out: output channels of ``cv3`` (block output width).
          bottle_args: sequence of ``(cv1in, cv1out, cv2out)`` tuples; entry
              ``k`` supplies the positional channel args of bottleneck ``k``.
          n: number of bottlenecks to build from the first ``n`` entries of
              ``bottle_args``.
          shortcut: residual flag forwarded to every bottleneck.
          g: conv group count forwarded to every bottleneck.

      Raises:
          IndexError: if ``bottle_args`` has fewer than ``n`` entries.
      """

      def __init__(self, cv1in, cv1out, cv2out, cv3out, bottle_args, n=1, shortcut=True, g=1):
          super(C3Pruned, self).__init__()
          # cv3 consumes the output of the LAST bottleneck actually built
          # (entry n - 1, whose last element is its output width), not
          # whatever entry happens to be last in bottle_args. With n == 0
          # the empty nn.Sequential is an identity, so cv1's output reaches
          # cv3 unchanged.
          cv3in = bottle_args[n - 1][-1] if n > 0 else cv1out
          self.cv1 = Conv(cv1in, cv1out, 1, 1)
          self.cv2 = Conv(cv1in, cv2out, 1, 1)
          self.cv3 = Conv(cv3in + cv2out, cv3out, 1)
          self.m = nn.Sequential(*[BottleneckPruned(*bottle_args[k], shortcut, g) for k in range(n)])

      def forward(self, x):
          return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
  class SPPFPruned(nn.Module):
      """Spatial Pyramid Pooling - Fast (SPPF) layer with pruned channel widths.

      ``cv1`` maps the input to ``cv1out`` channels. The same stride-1,
      ``k // 2``-padded max pool ``m`` is then applied three times in series.
      The ``cv1`` output and the three pooled maps are concatenated along
      dim 1 (hence ``cv1out * 4`` input channels for ``cv2``) and fused by
      ``cv2``.

      With the default ``k=5``, the cascaded pools cover 5x5, 9x9 and 13x13
      windows. ``k`` should be odd so each pool preserves the spatial size and
      the concatenation lines up.

      Args:
          cv1in: input channels of the layer.
          cv1out: output channels of ``cv1`` (pruned hidden width).
          cv2out: output channels of ``cv2`` (layer output width).
          k: max-pool kernel size.
      """

      def __init__(self, cv1in, cv1out, cv2out, k=5):
          super(SPPFPruned, self).__init__()
          self.cv1 = Conv(cv1in, cv1out, 1, 1)
          self.cv2 = Conv(cv1out * 4, cv2out, 1, 1)
          self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

      def forward(self, x):
          x = self.cv1(x)
          with warnings.catch_warnings():
              warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
              # All three pools run inside the suppression context, including
              # the last one, which previously sat on the return line.
              y1 = self.m(x)
              y2 = self.m(y1)
              y3 = self.m(y2)
          return self.cv2(torch.cat([x, y1, y2, y3], 1))