123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- # -*- coding: utf-8 -*-
- # @Time : 2021/5/24 下午4:36
- # @Author : midaskong
- # @File : prune_utils.py
- # @Description:
- import torch
- from copy import deepcopy
- import numpy as np
- import torch.nn.functional as F
- def gather_bn_weights(module_list):
- prune_idx = list(range(len(module_list)))
- size_list = [idx.weight.data.shape[0] for idx in module_list.values()]
- bn_weights = torch.zeros(sum(size_list))
- index = 0
- for i, idx in enumerate(module_list.values()):
- size = size_list[i]
- bn_weights[index:(index + size)] = idx.weight.data.abs().clone()
- index += size
- return bn_weights
- def gather_conv_weights(module_list):
- prune_idx = list(range(len(module_list)))
- size_list = [idx.weight.data.shape[0] for idx in module_list.values()]
- conv_weights = torch.zeros(sum(size_list))
- index = 0
- for i, idx in enumerate(module_list.values()):
- size = size_list[i]
- conv_weights[index:(index + size)] = idx.weight.data.abs().sum(dim=1).sum(dim=1).sum(dim=1).clone()
- index += size
- return conv_weights
- def obtain_bn_mask(bn_module, thre):
- thre = thre.cuda()
- mask = bn_module.weight.data.abs().ge(thre).float()
- return mask
- def obtain_conv_mask(conv_module, thre):
- thre = thre.cuda()
- mask = conv_module.weight.data.abs().sum(dim=1).sum(dim=1).sum(dim=1).ge(thre).float()
- return mask
- def uodate_pruned_yolov5_cfg(model, maskbndict):
- # save pruned yolov5 model in yaml format:
- # model:
- # model to be pruned
- # maskbndict:
- # key : module name
- # value : bn layer mask index
- return
|