model_get.py

import os
import torch

choice_dict = {
    'yolov7_cls': 'model_prepare(args).yolov7_cls()',
    'timm_model': 'model_prepare(args).timm_model()',
    'Alexnet': 'model_prepare(args).Alexnet()',
    'badnet': 'model_prepare(args).badnet()',
    'GoogleNet': 'model_prepare(args).GoogleNet()',
    'mobilenetv2': 'model_prepare(args).mobilenetv2()',
    'resnet': 'model_prepare(args).resnet()',
    'VGG19': 'model_prepare(args).VGG19()',
    'efficientnet': 'model_prepare(args).EfficientNetV2_S()',
    'LeNet': 'model_prepare(args).LeNet()'
}
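
# Note: the strings in choice_dict are evaluated with eval() inside model_get/prune, so they
# only work where both `args` (a local at the call site) and model_prepare are in scope.
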
def model_get(args):
    if os.path.exists(args.weight):  # prefer loading an existing model and resume training
        model_dict = torch.load(args.weight, map_location='cpu')
    else:  # build a new model
        if args.prune:  # model pruning
            model_dict = torch.load(args.prune_weight, map_location='cpu')
            model = model_dict['model']
            model = prune(args, model)
        elif args.timm:
            # model = model_prepare(args).timm_model()
            model = eval(choice_dict['timm_model'])
        else:
            model = eval(choice_dict[args.model])
        model_dict = {}
        model_dict['model'] = model
        model_dict['epoch_finished'] = 0  # number of epochs already trained
        model_dict['optimizer_state_dict'] = None  # optimizer / learning-rate state
        model_dict['ema_updates'] = 0  # EMA update counter
        model_dict['standard'] = 0  # evaluation metric
    return model_dict
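
# Note: when args.weight exists, the returned dict comes straight from the checkpoint and is
# assumed to already carry the same keys that the new-model branch initialises above
# ('model', 'epoch_finished', 'optimizer_state_dict', 'ema_updates', 'standard').
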
def prune(args, model):
    # record the weights of every BatchNorm2d layer
    BatchNorm2d_weight = []
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            BatchNorm2d_weight.append(module.weight.data.clone())
    BatchNorm2d_weight_abs = torch.cat(BatchNorm2d_weight, dim=0).abs()
    weight_len = len(BatchNorm2d_weight)
    # record which BN layer each weight belongs to
    BatchNorm2d_id = []
    for i in range(weight_len):
        BatchNorm2d_id.extend([i for _ in range(len(BatchNorm2d_weight[i]))])
    id_all = torch.tensor(BatchNorm2d_id)
    # select the channels to keep
    value, index = torch.sort(BatchNorm2d_weight_abs, dim=0, descending=True)
    boundary = int(len(index) * args.prune_ratio)
    prune_index = index[0:boundary]  # indices of the parameters to keep
    prune_index, _ = torch.sort(prune_index, dim=0, descending=False)
    prune_id = id_all[prune_index]
    # group the kept indices by layer
    index_list = [[] for _ in range(weight_len)]
    for i in range(len(prune_index)):
        index_list[prune_id[i]].append(prune_index[i])
    # convert the kept indices of each layer into layer-relative indices
    record_len = 0
    for i in range(weight_len):
        index_list[i] = torch.tensor(index_list[i])
        index_list[i] -= record_len
        if len(index_list[i]) == 0:  # a whole layer can be pruned away; keep at least one channel
            index_list[i] = torch.argmax(BatchNorm2d_weight[i], dim=0).unsqueeze(0)
        record_len += len(BatchNorm2d_weight[i])
    # build the pruned model
    args.prune_num = [len(_) for _ in index_list]
    prune_model = eval(choice_dict[args.model])
    # copy BN weights and part of the conv weights into the pruned model
    index = 0
    for module, prune_module in zip(model.modules(), prune_model.modules()):
        if isinstance(module, torch.nn.Conv2d):  # update part of the Conv2d weights
            print(f"Processing Conv2d layer, index: {index}, weight shape: {module.weight.data.shape}")
            if index > 0 and index - 1 < len(index_list):
                # show the kept indices of the previous BN layer
                print(f"Kept indices of the previous layer (index_list[{index - 1}]): {index_list[index - 1]}")
                # guard against out-of-range channel indices
                if index_list[index - 1].max().item() < module.weight.data.shape[1]:  # max index must be below the input channel count
                    weight = module.weight.data.clone()
                    if index < len(index_list):
                        weight = weight[:, index_list[index - 1], :, :]
                    if prune_module.weight.data.shape == weight.shape:
                        prune_module.weight.data = weight
                else:
                    print("Index out of range, skipping this layer")
            elif index == 0:
                weight = module.weight.data.clone()[index_list[index]]
                if prune_module.weight.data.shape == weight.shape:
                    prune_module.weight.data = weight
        if isinstance(module, torch.nn.BatchNorm2d):
            print(f"Updating BatchNorm2d layer, index: {index}, weight shape: {module.weight.data.shape}")
            if index < len(index_list) and len(index_list[index]) > 0:
                expected_size = prune_module.weight.data.size(0)  # channel count of the pruned layer
                actual_size = len(index_list[index])
                print(f"Expected size: {expected_size}, retained size: {actual_size}")
                if actual_size == expected_size:
                    prune_module.weight.data = module.weight.data.clone()[index_list[index]]
                    prune_module.bias.data = module.bias.data.clone()[index_list[index]]
                    prune_module.running_mean = module.running_mean.clone()[index_list[index]]
                    prune_module.running_var = module.running_var.clone()[index_list[index]]
                else:
                    print("Warning: number of kept channels does not match the pruned BatchNorm2d layer")
            index += 1
    return prune_model
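
# The routine above is BN-scale channel pruning: all BatchNorm2d gammas are ranked by absolute
# value, the top args.prune_ratio fraction of channels is kept (at least one per layer), and
# args.prune_num hands the per-layer channel counts to the constructor that builds prune_model.
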
class model_prepare:
    def __init__(self, args):
        self.args = args

    def timm_model(self):
        from model.timm_model import timm_model
        model = timm_model(self.args)
        return model

    def yolov7_cls(self):
        from model.yolov7_cls import yolov7_cls
        model = yolov7_cls(self.args)
        return model

    def Alexnet(self):
        from model.Alexnet import Alexnet
        model = Alexnet(self.args.input_channels, self.args.output_num, self.args.input_size)
        return model

    def badnet(self):
        from model.badnet import BadNet
        model = BadNet(self.args.input_channels, self.args.output_num)
        return model

    def GoogleNet(self):
        from model.GoogleNet import GoogLeNet
        model = GoogLeNet(self.args.input_channels, self.args.output_num)
        return model

    def mobilenetv2(self):
        from model.mobilenetv2 import MobileNetV2
        model = MobileNetV2(self.args.input_channels, self.args.output_num)
        return model

    def resnet(self):
        from model.resnet import ResNet18
        model = ResNet18(self.args.input_channels, self.args.output_num)
        return model

    def VGG19(self):
        from model.VGG19 import VGG19
        model = VGG19()
        return model

    def EfficientNetV2_S(self):
        from model.efficientnet import EfficientNetV2_S
        model = EfficientNetV2_S(self.args.input_channels, self.args.output_num)
        return model

    def LeNet(self):
        from model.LeNet import LeNet
        model = LeNet(self.args.input_channels, self.args.output_num, self.args.input_size)
        return model
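

# Minimal usage sketch (assumption: run from the repository root so the `model.*` imports
# resolve; the argument values below are illustrative, not the project's defaults).
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(
        weight='last.pt',  # resumed if this checkpoint file exists
        prune=False, prune_weight='', prune_ratio=0.5,
        timm=False, model='LeNet',
        input_channels=3, output_num=10, input_size=32)
    model_dict = model_get(args)
    print(type(model_dict['model']), model_dict['epoch_finished'])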