123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import os
- import torch
- from torch import nn
class model_prepare:
    """Factory that builds a network from the parsed run arguments.

    Each builder method imports its implementation lazily from the project's
    ``model`` package, so only the selected architecture's module is loaded.
    """

    def __init__(self, args):
        # Kept verbatim; every builder forwards this object to the model
        # constructor unchanged.
        self.args = args

    def timm_model(self):
        """Return ``model.timm_model.timm_model`` constructed with ``self.args``."""
        from model.timm_model import timm_model as build_timm_model

        return build_timm_model(self.args)

    def yolov7_cls(self):
        """Return ``model.yolov7_cls.yolov7_cls`` constructed with ``self.args``."""
        from model.yolov7_cls import yolov7_cls as build_yolov7_cls

        return build_yolov7_cls(self.args)
def model_get(args):
    """Return the checkpoint dictionary to train from.

    Resolution order:

    1. If ``args.weight`` exists on disk, that checkpoint is loaded and
       returned as-is (resume training).
    2. Otherwise, if ``args.prune`` is truthy, the model stored under the
       ``'model'`` key of ``args.prune_weight`` is loaded and passed through
       ``prune``.
    3. Otherwise, if ``args.timm`` is truthy, a timm model is built.
    4. Otherwise the builder registered for ``args.model`` is called.

    For cases 2-4 a fresh checkpoint dictionary is created around the model
    with zeroed training progress.
    """
    # NOTE(review): torch.load unpickles the file contents; only load
    # checkpoints that come from a trusted source.
    factory = model_prepare(args)
    choice_dict = {
        'resnet18': factory.timm_model,
        'efficientnetv2_s': factory.timm_model,
        'yolov7_cls': factory.yolov7_cls,
    }
    print(f"Pruning enabled: {args.prune}")

    # Resume path: an existing checkpoint wins over every other option.
    if os.path.exists(args.weight):
        print('Loading existing model to continue training...')
        return torch.load(args.weight, map_location='cpu')

    if args.prune:
        print("Loading model for pruning...")
        checkpoint = torch.load(args.prune_weight, map_location='cpu')
        network = checkpoint['model']
        print("Model type before pruning:", type(network))
        network = prune(args, network, choice_dict)
        print("Model type after pruning:", type(network))
    elif args.timm:
        network = factory.timm_model()
    else:
        build = choice_dict[args.model]
        network = build()

    return {
        'model': network,
        'epoch_finished': 0,
        'optimizer_state_dict': None,
        'ema_updates': 0,
        'standard': 0,
    }
def prune(args, model, choice_dict):
    """Build a channel-pruned copy of ``model`` from BatchNorm2d scale factors.

    All BatchNorm2d scale weights in ``model`` are ranked together by
    absolute value. The top ``args.prune_ratio`` fraction of channels is
    *kept* (so ``prune_ratio`` acts as a keep ratio); every layer keeps at
    least one channel. The per-layer kept-channel counts are written to
    ``args.prune_num`` and then ``choice_dict[args.model]()`` is called to
    build the smaller network, which receives the selected Conv2d and
    BatchNorm2d parameters.

    Assumptions (not checked here): every BatchNorm2d is affine, the
    builder reads ``args.prune_num`` to size its layers, and ``model`` and
    the rebuilt network yield their modules in the same order.

    Args:
        args: namespace providing ``prune_ratio`` and ``model``; receives
            the ``prune_num`` attribute as a side effect.
        model: trained ``nn.Module`` to prune.
        choice_dict: mapping from model name to a zero-argument builder.

    Returns:
        The newly built, pruned ``nn.Module``.

    Raises:
        TypeError: if ``model`` is not an ``nn.Module``.
        ValueError: if ``model`` contains no BatchNorm2d layer.
        IndexError: if a selected channel index exceeds a layer's size.
    """
    if not isinstance(model, nn.Module):
        raise TypeError("Expected model to be a PyTorch model instance")

    bn_weights = [
        module.weight.data.clone()
        for module in model.modules()
        if isinstance(module, nn.BatchNorm2d)
    ]
    weight_len = len(bn_weights)
    if weight_len == 0:
        raise ValueError("Model has no BatchNorm2d layers to prune.")

    # Rank every BatchNorm channel of the network together. The bookkeeping
    # is done on CPU so the index tensors are usable whatever device the
    # model lives on.
    all_abs = torch.cat([w.abs().cpu() for w in bn_weights])
    _, order = torch.sort(all_abs, descending=True)
    boundary = int(len(order) * args.prune_ratio)
    kept_global, _ = torch.sort(order[:boundary])

    # Translate global channel positions into per-layer local indices.
    index_list = []
    offset = 0
    for weight in bn_weights:
        size = len(weight)
        in_layer = (kept_global >= offset) & (kept_global < offset + size)
        local = kept_global[in_layer] - offset
        if local.numel() == 0:
            # Never remove a whole layer: keep its strongest channel.
            local = torch.argmax(weight.abs()).reshape(1).cpu()
        index_list.append(local)
        offset += size

    # The builder is expected to size its layers from args.prune_num.
    args.prune_num = [len(kept) for kept in index_list]
    prune_model = choice_dict[args.model]()

    index = 0  # position of the next BatchNorm2d layer
    for module, prune_module in zip(model.modules(), prune_model.modules()):
        if isinstance(module, nn.Conv2d) and index < weight_len:
            keep_out = index_list[index]
            if int(keep_out.max()) >= module.out_channels:
                raise IndexError("Index out of bounds for Conv2d output channels.")
            weight = module.weight.data[keep_out]
            # Input channels follow the previous BatchNorm's selection, but
            # only when those indices are valid for this convolution.
            if index > 0:
                keep_in = index_list[index - 1]
                if int(keep_in.max()) < module.in_channels:
                    weight = weight[:, keep_in, :, :]
            prune_module.weight.data = weight
            if module.bias is not None and prune_module.bias is not None:
                prune_module.bias.data = module.bias.data[keep_out]

        if isinstance(module, nn.BatchNorm2d):
            keep = index_list[index]
            if int(keep.max()) >= module.num_features:
                raise IndexError("Index out of bounds for BatchNorm2d features.")
            prune_module.weight.data = module.weight.data[keep]
            prune_module.bias.data = module.bias.data[keep]
            if module.running_mean is not None:
                prune_module.running_mean = module.running_mean[keep]
            if module.running_var is not None:
                prune_module.running_var = module.running_var[keep]
            # Print the parameter shapes of the pruned BatchNorm2d layer.
            if isinstance(prune_module, nn.BatchNorm2d):
                print("Pruned BatchNorm2d:")
                print("Weight:", prune_module.weight.data.shape)
                print("Bias:", prune_module.bias.data.shape)
                if prune_module.running_mean is not None:
                    print("Running Mean:", prune_module.running_mean.shape)
                if prune_module.running_var is not None:
                    print("Running Var:", prune_module.running_var.shape)
            index += 1

    return prune_model
|