import os
import torch
from torch import nn


class model_prepare:
    """Lazily import and build the model classes selectable by name.

    Each builder imports its project module only when it is called, so backends
    that are never selected are never imported.
    """

    def __init__(self, args):
        # Namespace forwarded unchanged to every model constructor.
        self.args = args

    def timm_model(self):
        """Build a ``model.timm_model.timm_model`` from ``self.args``."""
        from model.timm_model import timm_model
        model = timm_model(self.args)
        return model

    def yolov7_cls(self):
        """Build a ``model.yolov7_cls.yolov7_cls`` from ``self.args``."""
        from model.yolov7_cls import yolov7_cls
        model = yolov7_cls(self.args)
        return model


def model_get(args):
    """Return a training-state dict from a checkpoint, a pruned model, or a fresh model.

    Resolution order:
      1. ``args.weight`` names an existing file -> load and return that checkpoint as-is.
      2. ``args.prune`` -> load ``args.prune_weight``, prune its ``'model'`` entry with
         :func:`prune`, and wrap the result in a fresh state dict.
      3. ``args.timm`` -> build a timm model.
      4. otherwise -> build ``args.model`` from the supported-model table.

    Args:
        args: namespace read for ``weight``, ``prune``, ``prune_weight``, ``timm`` and
            ``model`` (plus whatever the model builders and :func:`prune` read).

    Returns:
        dict. In cases 2-4 it has keys ``model``, ``epoch_finished`` (0),
        ``optimizer_state_dict`` (None), ``ema_updates`` (0) and ``standard`` (0);
        in case 1 it is whatever the checkpoint file contains.

    Raises:
        KeyError: if ``args.model`` is not a supported model name.
    """
    prepare = model_prepare(args)
    choice_dict = {
        'resnet18': prepare.timm_model,
        'efficientnetv2_s': prepare.timm_model,
        'yolov7_cls': prepare.yolov7_cls,
    }
    print(f"Pruning enabled: {args.prune}")

    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which refuses
    # pickled nn.Module objects such as the 'model' entry read below; loading full
    # checkpoints may need weights_only=False -- only ever for trusted files.
    # An unset weight path (None/'') means "no checkpoint", not a filesystem error.
    if args.weight and os.path.exists(args.weight):
        print('Loading existing model to continue training...')
        return torch.load(args.weight, map_location='cpu')

    if args.prune:
        print("Loading model for pruning...")
        checkpoint = torch.load(args.prune_weight, map_location='cpu')
        model = checkpoint['model']
        print("Model type before pruning:", type(model))
        model = prune(args, model, choice_dict)
        print("Model type after pruning:", type(model))
    elif args.timm:
        model = prepare.timm_model()
    else:
        try:
            build = choice_dict[args.model]
        except KeyError:
            raise KeyError(
                f"Unsupported model {args.model!r}; expected one of {sorted(choice_dict)}"
            ) from None
        model = build()

    return {'model': model, 'epoch_finished': 0, 'optimizer_state_dict': None,
            'ema_updates': 0, 'standard': 0}


def prune(args, model, choice_dict):
    """Channel-prune ``model`` by BatchNorm2d scale magnitude and copy weights into a slim model.

    All BatchNorm2d ``weight`` (gamma) values are ranked globally by absolute value, and
    the top ``int(total * args.prune_ratio)`` channels are retained. ``prune_ratio``
    therefore acts as the fraction of channels to KEEP. Every BN layer retains at least
    one channel: a layer with no selected channel keeps its single largest-|gamma| channel.

    The per-layer count of retained channels is stored in ``args.prune_num`` before
    ``choice_dict[args.model]()`` builds the slim model. The builder can therefore size
    its layers from it (presumably it does -- confirm in the model constructors).

    Args:
        args: namespace read for ``prune_ratio`` and ``model``; ``prune_num`` is written.
        model: the trained ``nn.Module`` to prune.
        choice_dict: mapping from model name to a zero-argument model builder.

    Returns:
        nn.Module: the model built by ``choice_dict[args.model]()`` with the selected
        channels copied in.

    Raises:
        TypeError: if ``model`` is not an ``nn.Module``.
        ValueError: if ``model`` contains no BatchNorm2d layers.
        IndexError: if selected channel indices exceed a layer's channel count.
    """
    if not isinstance(model, nn.Module):
        raise TypeError("Expected model to be a PyTorch model instance")

    bn_weights = [m.weight.data.clone() for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    if not bn_weights:
        raise ValueError("Model has no BatchNorm2d layers to prune.")

    index_list = _select_bn_channels(bn_weights, args.prune_ratio)
    args.prune_num = [len(x) for x in index_list]
    prune_model = choice_dict[args.model]()
    _copy_pruned_weights(model, prune_model, index_list)
    return prune_model


def _select_bn_channels(bn_weights, ratio):
    """Return, per BN layer, an ascending LongTensor of the local channel indices to keep."""
    device = bn_weights[0].device
    sizes = torch.tensor([len(w) for w in bn_weights], device=device)
    # layer_of[g] = which BN layer global channel g belongs to.
    layer_of = torch.repeat_interleave(torch.arange(len(bn_weights), device=device), sizes)
    # Start offset of each layer inside the concatenated gamma vector.
    offsets = torch.cumsum(sizes, 0) - sizes

    all_abs = torch.cat([w.abs() for w in bn_weights])
    _, order = torch.sort(all_abs, descending=True)
    boundary = int(len(order) * ratio)
    keep, _ = torch.sort(order[:boundary])  # ascending global indices of retained channels
    keep_layer = layer_of[keep]

    index_list = []
    for i, weight in enumerate(bn_weights):
        local = keep[keep_layer == i] - offsets[i]
        if local.numel() == 0:
            # Never drop a whole layer: keep its largest-|gamma| channel (local index).
            local = torch.argmax(weight.abs()).reshape(1)
        index_list.append(local)
    return index_list


def _copy_pruned_weights(model, prune_model, index_list):
    """Copy the selected Conv2d/BatchNorm2d channels of ``model`` into ``prune_model`` in place.

    Modules are paired positionally via ``zip(model.modules(), prune_model.modules())``.
    The k-th Conv2d (while k < number of BN layers) keeps the output channels chosen for
    the k-th BN layer. For k > 0 it also keeps the input channels chosen for the (k-1)-th
    BN layer, when those indices fit.
    NOTE(review): this assumes a plain Conv -> BN chain in module order; residual or
    concat topologies would need their own channel mapping.
    """
    weight_len = len(index_list)
    index = 0  # position of the next BatchNorm2d layer, i.e. into index_list
    for module, prune_module in zip(model.modules(), prune_model.modules()):
        if isinstance(module, nn.Conv2d) and index < weight_len:
            out_idx = index_list[index]
            if int(out_idx.max()) >= module.out_channels:
                raise IndexError("Index out of bounds for Conv2d output channels.")
            weight = module.weight.data.clone()[out_idx]
            if index > 0 and int(index_list[index - 1].max()) < module.in_channels:
                # Input channels follow the channels kept by the preceding BN layer.
                weight = weight[:, index_list[index - 1], :, :]
            prune_module.weight.data = weight
            if module.bias is not None and prune_module.bias is not None:
                # The bias is per output channel, so it follows the same selection.
                prune_module.bias.data = module.bias.data.clone()[out_idx]
        elif isinstance(module, nn.BatchNorm2d):
            idx = index_list[index]
            if int(idx.max()) >= module.num_features:
                raise IndexError("Index out of bounds for BatchNorm2d features.")
            prune_module.weight.data = module.weight.data.clone()[idx]
            prune_module.bias.data = module.bias.data.clone()[idx]
            prune_module.running_mean = module.running_mean.clone()[idx]
            prune_module.running_var = module.running_var.clone()[idx]
            # Print the parameter shapes of the pruned BatchNorm2d layer.
            if isinstance(prune_module, nn.BatchNorm2d):
                print("Pruned BatchNorm2d:")
                print("Weight:", prune_module.weight.data.shape)
                print("Bias:", prune_module.bias.data.shape)
                print("Running Mean:", prune_module.running_mean.shape)
                print("Running Var:", prune_module.running_var.shape)
            index += 1