123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
import os

import torch
# Registry of the supported model names, i.e. the accepted values of ``args.model``.
# Each value is a Python expression string that builds the network through the
# ``model_prepare`` method of the same name; it refers to a variable called
# ``args``, which must be in scope wherever the string is evaluated (e.g. by eval()).
choice_dict = {
    name: f'model_prepare(args).{name}()'
    for name in ('LeNet', 'Alexnet', 'VGG19', 'VGG16', 'GoogleNet', 'resnet')
}
def _load_checkpoint(path):
    """Load the checkpoint stored at ``path`` onto the CPU and return it.

    Checkpoints built by ``model_get`` keep the model object itself (not just a
    state_dict) under ``'model'``. torch >= 2.6 defaults ``torch.load`` to
    ``weights_only=True``, which refuses to unpickle such objects, so full
    unpickling is requested explicitly whenever this torch version supports the
    keyword. Unpickling can execute arbitrary code: only load trusted files.
    """
    import inspect

    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def model_get(args):
    """Return a training checkpoint dict for ``args``.

    Resolution order:
      1. If ``args.weight`` is set and the file exists, that checkpoint is loaded
         unchanged (resume training).
      2. Otherwise, if ``args.prune`` is set, the model stored under ``'model'`` in
         ``args.prune_weight`` is pruned with ``prune(args, model)`` and wrapped in
         a fresh checkpoint.
      3. Otherwise a new ``args.model`` network is built through ``model_prepare``
         and wrapped in a fresh checkpoint.

    A fresh checkpoint has the keys ``model``, ``epoch_finished`` (0),
    ``optimizer_state_dict`` (None), ``ema_updates`` (0) and ``standard`` (0).

    Raises:
        KeyError: if a new model is requested and ``args.model`` is not a key of
            ``choice_dict``.
    """
    if args.weight and os.path.exists(args.weight):  # prefer resuming from an existing checkpoint
        return _load_checkpoint(args.weight)

    if args.prune:  # prune a previously trained model
        model = _load_checkpoint(args.prune_weight)['model']
        model = prune(args, model)
    else:  # build a brand-new model
        if args.model not in choice_dict:  # same KeyError the registry lookup used to raise
            raise KeyError(args.model)
        # Call the model_prepare factory method of the same name directly instead of eval().
        model = getattr(model_prepare(args), args.model)()

    return {
        'model': model,
        'epoch_finished': 0,  # number of epochs already trained
        'optimizer_state_dict': None,  # optimizer / learning-rate state
        'ema_updates': 0,  # EMA update counter
        'standard': 0,  # evaluation metric
    }
def _select_kept_channels(bn_weights, prune_ratio):
    """Return, for each BN weight vector, the ascending local indices of the channels to keep.

    All BN weights are ranked together by magnitude and the largest
    ``int(total * prune_ratio)`` survive, so ``prune_ratio`` is the fraction of
    channels *retained*. A layer that would lose every channel keeps its single
    largest-magnitude one.
    """
    magnitudes = torch.cat(bn_weights, dim=0).abs()
    keep_count = int(len(magnitudes) * prune_ratio)
    _, order = torch.sort(magnitudes, dim=0, descending=True)
    keep_mask = torch.zeros(len(magnitudes), dtype=torch.bool, device=magnitudes.device)
    keep_mask[order[:keep_count]] = True

    index_list = []
    offset = 0  # start of the current layer inside the concatenated vector
    for weight in bn_weights:
        size = len(weight)
        kept = torch.nonzero(keep_mask[offset:offset + size]).flatten()
        if kept.numel() == 0:
            # The whole layer would vanish: keep its largest-magnitude channel,
            # using the same |gamma| criterion as the global ranking.
            kept = torch.argmax(weight.abs(), dim=0).unsqueeze(0)
        index_list.append(kept)
        offset += size
    return index_list


def _transfer_pruned_weights(model, prune_model, index_list):
    """Copy the kept slices of Conv2d / BatchNorm2d parameters from ``model`` into ``prune_model``.

    Modules of the two networks are paired in ``modules()`` order. The mapping
    assumes a plain sequential layout: a conv's output channels are pruned like
    the next BatchNorm2d (``index_list[bn_seen]``), and its input channels like the
    previous one (``index_list[bn_seen - 1]``). A layer is only copied when the
    sliced tensor has exactly the shape of the ``prune_model`` layer; otherwise
    that layer keeps its fresh initialisation.
    """
    bn_seen = 0  # number of BatchNorm2d layers visited so far
    for module, prune_module in zip(model.modules(), prune_model.modules()):
        if isinstance(module, torch.nn.Conv2d):
            weight = module.weight.data
            out_idx = index_list[bn_seen] if bn_seen < len(index_list) else None
            if out_idx is not None:
                weight = weight[out_idx]  # prune output channels like the following BN
            if bn_seen > 0:
                in_idx = index_list[bn_seen - 1]
                if in_idx.max().item() >= weight.shape[1]:
                    # The previous BN evidently does not feed this conv (e.g. a branch or shortcut).
                    print("索引越界,跳过当前层的处理")
                    continue
                weight = weight[:, in_idx]  # prune input channels like the preceding BN
            if prune_module.weight.data.shape == weight.shape:
                prune_module.weight.data = weight.clone()
                if module.bias is not None and prune_module.bias is not None:
                    bias = module.bias.data
                    prune_module.bias.data = (bias[out_idx] if out_idx is not None else bias).clone()
        elif isinstance(module, torch.nn.BatchNorm2d):
            if bn_seen < len(index_list):
                kept = index_list[bn_seen]
                # Compare with the *pruned* layer's width, not the original one.
                if len(kept) == prune_module.weight.data.size(0):
                    prune_module.weight.data = module.weight.data[kept].clone()
                    prune_module.bias.data = module.bias.data[kept].clone()
                    prune_module.running_mean = module.running_mean[kept].clone()
                    prune_module.running_var = module.running_var[kept].clone()
                else:
                    print("警告: 剪枝后的大小与期望的 BatchNorm2d 层大小不匹配")
            bn_seen += 1


def prune(args, model):
    """Build a channel-pruned network from ``model`` guided by BatchNorm2d scales.

    The BatchNorm2d weights of ``model`` decide which channels survive (see
    ``_select_kept_channels``). The per-layer kept counts are written to
    ``args.prune_num`` before a fresh ``args.model`` network is built through
    ``model_prepare``. Matching parameter slices are then copied from ``model``
    into it (see ``_transfer_pruned_weights``).

    Args:
        args: namespace providing ``prune_ratio`` (fraction of BN channels kept)
            and ``model``; ``args.prune_num`` is set as a side effect.
        model: the trained network. It is assumed to be the architecture named by
            ``args.model``, so that both networks' modules line up.

    Returns:
        The newly built network with the transferable weights copied in.

    Raises:
        ValueError: if ``model`` has no BatchNorm2d layer to rank channels by.
        KeyError: if ``args.model`` is not a key of ``choice_dict``.
    """
    bn_weights = [m.weight.data for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    if not bn_weights:
        raise ValueError('prune() needs at least one BatchNorm2d layer to rank channels')

    index_list = _select_kept_channels(bn_weights, args.prune_ratio)
    args.prune_num = [len(kept) for kept in index_list]

    if args.model not in choice_dict:  # same KeyError the registry lookup used to raise
        raise KeyError(args.model)
    # Call the model_prepare factory method of the same name directly instead of eval().
    prune_model = getattr(model_prepare(args), args.model)()

    _transfer_pruned_weights(model, prune_model, index_list)
    return prune_model
class model_prepare:
    """Factory for the supported networks, configured from an ``args`` namespace.

    Each method lazily imports its architecture from the local ``model`` package
    and returns a newly constructed instance. Method names match the keys of
    ``choice_dict``, so a network can be chosen by name.

    NOTE(review): ``prune()`` writes per-layer channel counts to
    ``args.prune_num``, but none of these constructors receive that value. Confirm
    the model classes pick it up some other way; otherwise "pruned" models are
    built at full width.
    """

    def __init__(self, args):
        # Settings namespace; the methods read input_channels, output_num and input_size from it.
        self.args = args

    def LeNet(self):
        """Return ``LeNet(input_channels, output_num, input_size)``."""
        from model.LeNet import LeNet as network
        cfg = self.args
        return network(cfg.input_channels, cfg.output_num, cfg.input_size)

    def Alexnet(self):
        """Return ``Alexnet(input_channels, output_num, input_size)``."""
        from model.Alexnet import Alexnet as network
        cfg = self.args
        return network(cfg.input_channels, cfg.output_num, cfg.input_size)

    def VGG19(self):
        """Return ``VGG19()`` built with its own defaults; no settings from ``args`` are passed."""
        from model.VGG19 import VGG19 as network
        return network()

    def VGG16(self):
        """Return ``VGG16(input_size)``.

        NOTE(review): the class is imported from ``model.VGG19``, not from a
        ``model.VGG16`` module. Confirm that this is intentional.
        """
        from model.VGG19 import VGG16 as network
        return network(self.args.input_size)

    def GoogleNet(self):
        """Return ``GoogLeNet(input_channels, output_num)``."""
        from model.GoogleNet import GoogLeNet as network
        cfg = self.args
        return network(cfg.input_channels, cfg.output_num)

    def resnet(self):
        """Return ``ResNet18(input_channels, output_num)``."""
        from model.resnet import ResNet18 as network
        cfg = self.args
        return network(cfg.input_channels, cfg.output_num)
|