"""Magnitude-based (unstructured) pruning of Conv weights in ONNX model files.

Every ``*.onnx`` file under a directory tree is loaded, the smallest-magnitude
fraction of its Conv weight values (ranked globally across all Conv weight
initializers of that model) is set to zero, and the result is saved as a new
``*_pruned.onnx`` file.
"""
import argparse
import os

import numpy as np
import onnx
from onnx import numpy_helper


def _conv_weight_names(graph):
    """Return the names of tensors wired to the weight input (input 1) of the graph's Conv nodes.

    Only nodes directly in ``graph`` are inspected; Conv nodes inside
    control-flow subgraphs (If/Loop bodies) are not visited.
    """
    return {
        node.input[1]
        for node in graph.node
        if node.op_type == 'Conv' and len(node.input) > 1
    }


def _magnitude_threshold(weights, pruning_percentage):
    """Return the magnitude below which weights are zeroed.

    Values with ``abs(w) < threshold`` are pruned, which is at most
    ``int(weights.size * pruning_percentage)`` values (fewer when several
    values tie at the threshold).
    """
    num_to_prune = int(weights.size * pruning_percentage)
    if num_to_prune >= weights.size:
        # Pruning everything: np.partition would index past the end, so use a
        # threshold that every finite value falls below.
        return np.inf
    return np.partition(np.abs(weights), num_to_prune)[num_to_prune]


def prune_weights(model_path, pruned_model, pruning_percentage=0.05):
    """Zero the smallest-magnitude Conv weights of an ONNX model and save the result.

    The loaded model is modified in place (only the Conv weight initializers
    are replaced), so its opset imports, IR version, metadata, value_info and
    any other initializers are preserved in the saved file.

    Args:
        model_path: path of the source ``.onnx`` file.
        pruned_model: path the pruned model is written to.
        pruning_percentage: fraction in [0, 1] of all Conv weight values,
            pooled across every Conv weight initializer, to set to zero.

    Raises:
        ValueError: if ``pruning_percentage`` is outside [0, 1].
    """
    if not 0.0 <= pruning_percentage <= 1.0:
        raise ValueError(
            f'pruning_percentage must be in [0, 1], got {pruning_percentage}')

    model = onnx.load(model_path)
    graph = model.graph

    # Collect the weight initializers of all Conv nodes (computed once).
    conv_weight_names = _conv_weight_names(graph)
    targets = [
        (idx, init)
        for idx, init in enumerate(graph.initializer)
        if init.name in conv_weight_names
    ]

    # Models without Conv weight initializers are saved unchanged instead of
    # crashing on an empty concatenate (which would abort a whole batch run).
    if targets:
        # to_array may return a read-only view of the raw bytes, hence .copy().
        arrays = [numpy_helper.to_array(init).copy() for _, init in targets]

        # The threshold is global: all Conv weights are ranked together.
        all_weights = np.concatenate([arr.ravel() for arr in arrays])
        threshold = _magnitude_threshold(all_weights, pruning_percentage)

        for (idx, init), weights in zip(targets, arrays):
            weights[np.abs(weights) < threshold] = 0
            graph.initializer[idx].CopyFrom(
                numpy_helper.from_array(weights, init.name))

    onnx.save(model, pruned_model)


def find_onnx_files(root_dir):
    """Return the paths of all files ending in ``.onnx`` under ``root_dir``, recursively."""
    onnx_files = []
    # Walk the root directory and all of its subdirectories.
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.endswith('.onnx'):
                onnx_files.append(os.path.join(dirpath, filename))
    return onnx_files


def _pruned_output_path(onnx_file, pruned_saved_dir=None):
    """Build the output path ``<dir>/<stem>_pruned.onnx`` for ``onnx_file``.

    ``<dir>`` is ``pruned_saved_dir`` when it is non-empty, otherwise the
    directory of ``onnx_file``. Only the trailing ``.onnx`` extension is
    replaced, so ``.onnx`` appearing elsewhere in the path is left intact.
    """
    name = os.path.basename(onnx_file)
    if name.endswith('.onnx'):
        name = name[:-len('.onnx')]
    out_dir = pruned_saved_dir if pruned_saved_dir else os.path.dirname(onnx_file)
    return os.path.join(out_dir, name + '_pruned.onnx')


def main():
    """Command-line entry point: prune every ONNX file under ``--target_dir``."""
    parser = argparse.ArgumentParser(description='模型文件剪枝工具')
    parser.add_argument('--target_dir', default=None, type=str, help='待剪枝的模型文件存放根目录,支持子文件夹递归处理')
    parser.add_argument('--pruned_saved_dir', default=None, type=str, help='剪枝模型文件保存目录,默认为None,表示与原始onnx权重文件放在同一目录下')
    parser.add_argument('--percent', default=0.05, type=float, help='规则剪枝百分比')
    args, _ = parser.parse_known_args()

    if args.target_dir is None:
        raise Exception("模型目录参数不可为空")

    # The file list is fully collected before any output is written.
    onnx_files = find_onnx_files(args.target_dir)

    if args.pruned_saved_dir:
        # NOTE: files with the same name in different subdirectories are
        # written to the same output path here, and later ones overwrite
        # earlier ones.
        os.makedirs(args.pruned_saved_dir, exist_ok=True)

    for onnx_file in onnx_files:
        pruned_file = _pruned_output_path(onnx_file, args.pruned_saved_dir)
        prune_weights(model_path=onnx_file, pruned_model=pruned_file,
                      pruning_percentage=args.percent)


if __name__ == '__main__':
    main()