prune_tool.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """
  2. 对onnx权重文件进行规则剪枝
  3. """
  4. import argparse
  5. import os
  6. import onnx
  7. import numpy as np
  8. from onnx import numpy_helper
  9. def prune_weights(model_path, pruned_model, pruning_percentage=0.05):
  10. model = onnx.load(model_path)
  11. # 获取所有权重的初始化器
  12. weight_initializers = [init for init in model.graph.initializer if
  13. init.name in {node.input[1] for node in model.graph.node if node.op_type == 'Conv'}]
  14. # 收集所有权重的绝对值
  15. all_weights = np.concatenate([numpy_helper.to_array(init).flatten() for init in weight_initializers])
  16. # 计算阈值
  17. num_weights_to_prune = int(len(all_weights) * pruning_percentage)
  18. threshold = np.partition(np.abs(all_weights), num_weights_to_prune)[num_weights_to_prune]
  19. # 剪枝权重并更新初始化器
  20. new_initializers = []
  21. for init in model.graph.initializer:
  22. if init.name in {node.input[1] for node in model.graph.node if node.op_type == 'Conv'}:
  23. weights = numpy_helper.to_array(init).copy()
  24. # 根据阈值剪枝权重
  25. weights[np.abs(weights) < threshold] = 0
  26. new_initializer = numpy_helper.from_array(weights, init.name)
  27. new_initializers.append(new_initializer)
  28. else:
  29. new_initializers.append(init)
  30. # 创建新的图
  31. new_graph = onnx.helper.make_graph(
  32. nodes=model.graph.node,
  33. name=model.graph.name,
  34. inputs=model.graph.input,
  35. outputs=model.graph.output,
  36. initializer=new_initializers
  37. )
  38. # 创建新的模型
  39. new_model = onnx.helper.make_model(new_graph, producer_name='onnx-example')
  40. # 保存剪枝后的模型
  41. onnx.save(new_model, pruned_model)
  42. def find_onnx_files(root_dir):
  43. onnx_files = []
  44. # 遍历根目录及其子目录
  45. for dirpath, _, filenames in os.walk(root_dir):
  46. # 查找所有以 .onnx 结尾的文件
  47. for filename in filenames:
  48. if filename.endswith('.onnx'):
  49. # 获取完整路径并添加到列表
  50. onnx_files.append(os.path.join(dirpath, filename))
  51. return onnx_files
  52. if __name__ == '__main__':
  53. parser = argparse.ArgumentParser(description='模型文件剪枝工具')
  54. parser.add_argument('--target_dir', default=None, type=str, help='待剪枝的模型文件存放根目录,支持子文件夹递归处理')
  55. parser.add_argument('--pruned_saved_dir', default=None, type=str, help='剪枝模型文件保存目录,默认为None,表示与原始onnx权重文件放在同一目录下')
  56. parser.add_argument('--percent', default=0.05, type=float, help='规则剪枝百分比')
  57. args, _ = parser.parse_known_args()
  58. if args.target_dir is None:
  59. raise Exception("模型目录参数不可为空")
  60. onnx_files = find_onnx_files(args.target_dir)
  61. for onnx_file in onnx_files:
  62. if args.pruned_saved_dir:
  63. pruned_file = args.pruned_saved_dir + '/' + os.path.basename(onnx_file).replace('.onnx', '_pruned.onnx')
  64. else:
  65. pruned_file = onnx_file.replace('.onnx', '_pruned.onnx')
  66. prune_weights(model_path=onnx_file, pruned_model=pruned_file, pruning_percentage=args.percent)