prune_utils.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2021/5/24 下午4:36
  3. # @Author : midaskong
  4. # @File : prune_utils.py
  5. # @Description:
  6. import torch
  7. from copy import deepcopy
  8. import numpy as np
  9. import torch.nn.functional as F
  10. def gather_bn_weights(module_list):
  11. prune_idx = list(range(len(module_list)))
  12. size_list = [idx.weight.data.shape[0] for idx in module_list.values()]
  13. bn_weights = torch.zeros(sum(size_list))
  14. index = 0
  15. for i, idx in enumerate(module_list.values()):
  16. size = size_list[i]
  17. bn_weights[index:(index + size)] = idx.weight.data.abs().clone()
  18. index += size
  19. return bn_weights
  20. def gather_conv_weights(module_list):
  21. prune_idx = list(range(len(module_list)))
  22. size_list = [idx.weight.data.shape[0] for idx in module_list.values()]
  23. conv_weights = torch.zeros(sum(size_list))
  24. index = 0
  25. for i, idx in enumerate(module_list.values()):
  26. size = size_list[i]
  27. conv_weights[index:(index + size)] = idx.weight.data.abs().sum(dim=1).sum(dim=1).sum(dim=1).clone()
  28. index += size
  29. return conv_weights
  30. def obtain_bn_mask(bn_module, thre):
  31. thre = thre.cuda()
  32. mask = bn_module.weight.data.abs().ge(thre).float()
  33. return mask
  34. def obtain_conv_mask(conv_module, thre):
  35. thre = thre.cuda()
  36. mask = conv_module.weight.data.abs().sum(dim=1).sum(dim=1).sum(dim=1).ge(thre).float()
  37. return mask
  38. def uodate_pruned_yolov5_cfg(model, maskbndict):
  39. # save pruned yolov5 model in yaml format:
  40. # model:
  41. # model to be pruned
  42. # maskbndict:
  43. # key : module name
  44. # value : bn layer mask index
  45. return