# test_model_get.py

import os
import torch
from torch import nn


class model_prepare:
    # Thin factory around the supported backbones; each method builds a model from args
    def __init__(self, args):
        self.args = args

    def timm_model(self):
        from model.timm_model import timm_model
        model = timm_model(self.args)
        return model

    def yolov7_cls(self):
        from model.yolov7_cls import yolov7_cls
        model = yolov7_cls(self.args)
        return model

def model_get(args):
    choice_dict = {
        'resnet18': model_prepare(args).timm_model,
        'efficientnetv2_s': model_prepare(args).timm_model,
        'yolov7_cls': model_prepare(args).yolov7_cls,
    }
    print(f"Pruning enabled: {args.prune}")
    if os.path.exists(args.weight):
        # Resume from an existing checkpoint
        print('Loading existing model to continue training...')
        model_dict = torch.load(args.weight, map_location='cpu')
    else:
        if args.prune:
            # Load the checkpoint to be pruned and replace its model with a pruned copy
            print("Loading model for pruning...")
            model_dict = torch.load(args.prune_weight, map_location='cpu')
            model = model_dict['model']
            print("Model type before pruning:", type(model))
            model = prune(args, model, choice_dict)
            print("Model type after pruning:", type(model))
        elif args.timm:
            model = model_prepare(args).timm_model()
        else:
            model = choice_dict[args.model]()  # the dict stores constructors, so call to instantiate
        model_dict = {'model': model, 'epoch_finished': 0, 'optimizer_state_dict': None,
                      'ema_updates': 0, 'standard': 0}
    return model_dict

def prune(args, model, choice_dict):
    if not isinstance(model, nn.Module):
        raise TypeError("Expected model to be a PyTorch model instance")
    # Collect the gamma weights of every BatchNorm2d layer; their magnitude ranks channel importance
    BatchNorm2d_weight = [module.weight.data.clone() for module in model.modules()
                          if isinstance(module, nn.BatchNorm2d)]
    BatchNorm2d_weight_abs = torch.cat([w.abs() for w in BatchNorm2d_weight])
    weight_len = len(BatchNorm2d_weight)
    # Record which BatchNorm2d layer each concatenated channel belongs to
    BatchNorm2d_id = [i for i in range(weight_len) for _ in range(len(BatchNorm2d_weight[i]))]
    id_all = torch.tensor(BatchNorm2d_id)
    # Keep the channels with the largest |gamma|: sort descending and take the top prune_ratio fraction
    value, index = torch.sort(BatchNorm2d_weight_abs, descending=True)
    boundary = int(len(index) * args.prune_ratio)
    prune_index = index[:boundary]
    prune_index, _ = torch.sort(prune_index)
    prune_id = id_all[prune_index]
    # Group the kept (global) channel indices by layer
    index_list = [[] for _ in range(weight_len)]
    for i in range(len(prune_index)):
        index_list[prune_id[i]].append(prune_index[i])
    # Convert global indices to per-layer indices; keep at least one channel in every layer
    for i, indices in enumerate(index_list):
        offset = sum(len(BatchNorm2d_weight[j]) for j in range(i))
        if indices:
            index_list[i] = torch.tensor(indices) - offset
        else:
            index_list[i] = torch.tensor([torch.argmax(BatchNorm2d_weight[i].abs())])
    args.prune_num = [len(x) for x in index_list]
    # Build a fresh model of the same architecture and copy over the kept channels
    prune_model = choice_dict[args.model]()
    index = 0
    for module, prune_module in zip(model.modules(), prune_model.modules()):
        if isinstance(module, nn.Conv2d) and index < weight_len:
            if max(index_list[index]) >= module.out_channels:
                raise IndexError("Index out of bounds for Conv2d output channels.")
            # Keep the selected output channels, and the previous layer's kept input channels
            weight = module.weight.data.clone()[index_list[index]]
            if index > 0 and max(index_list[index - 1]) < module.in_channels:
                weight = weight[:, index_list[index - 1], :, :]
            prune_module.weight.data = weight
        if isinstance(module, nn.BatchNorm2d):
            if max(index_list[index]) >= module.num_features:
                raise IndexError("Index out of bounds for BatchNorm2d features.")
            prune_module.weight.data = module.weight.data.clone()[index_list[index]]
            prune_module.bias.data = module.bias.data.clone()[index_list[index]]
            prune_module.running_mean = module.running_mean.clone()[index_list[index]]
            prune_module.running_var = module.running_var.clone()[index_list[index]]
            # Print the parameter shapes of the pruned BatchNorm2d layer
            if isinstance(prune_module, nn.BatchNorm2d):
                print("Pruned BatchNorm2d:")
                print("Weight:", prune_module.weight.data.shape)
                print("Bias:", prune_module.bias.data.shape)
                print("Running Mean:", prune_module.running_mean.shape)
                print("Running Var:", prune_module.running_var.shape)
            index += 1
    return prune_model
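

# --- Usage sketch (not part of the original file; assumptions labeled below) ---
# A minimal example of how model_get might be driven from a training script.
# The attribute names mirror the ones read above (weight, prune, prune_weight,
# prune_ratio, timm, model); the default paths and values here are made up for
# illustration, and the local 'model' package must be importable for the
# constructors in choice_dict to succeed.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', default='weights/last.pt', type=str, help='checkpoint to resume from')
    parser.add_argument('--prune', action='store_true', help='prune a trained checkpoint instead of building a new model')
    parser.add_argument('--prune_weight', default='weights/prune_source.pt', type=str, help='checkpoint to prune')
    parser.add_argument('--prune_ratio', default=0.5, type=float, help='fraction of BatchNorm2d channels retained by prune()')
    parser.add_argument('--timm', action='store_true', help='build the backbone directly through timm_model')
    parser.add_argument('--model', default='resnet18', type=str, help='key into choice_dict')
    args = parser.parse_args()

    model_dict = model_get(args)
    print(type(model_dict['model']))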