# export_onnx.py
  1. import os
  2. import shutil
  3. import torch
  4. import argparse
  5. from model.layer import deploy
  6. # -------------------------------------------------------------------------------------------------------------------- #
  7. parser = argparse.ArgumentParser(description='|将pt模型转为onnx,同时导出类别信息|')
  8. parser.add_argument('--weight', default='checkpoints/Alexnet/wm_embed/last.pt', type=str, help='|模型位置|')
  9. parser.add_argument('--input_size', default=32, type=int, help='|输入图片大小|')
  10. parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
  11. parser.add_argument('--batch', default=0, type=int, help='|输入图片批量,0为动态|')
  12. parser.add_argument('--sim', default=False, type=bool, help='|使用onnxsim压缩简化模型|')
  13. parser.add_argument('--device', default='cuda', type=str, help='|在哪个设备上加载模型|')
  14. parser.add_argument('--float16', default=True, type=bool, help='|转换的onnx模型数据类型,需要GPU,False时为float32|')
  15. parser.add_argument('--save_path', default='checkpoints/Alexnet/wm_embed', type=str, help='|移动存储位置|')
  16. args = parser.parse_args()
  17. args.weight = args.weight.split('.')[0] + '.pt'
  18. args.save_name = args.weight.split('.')[0] + '.onnx'
  19. # -------------------------------------------------------------------------------------------------------------------- #
  20. assert os.path.exists(args.weight), f'! 没有找到模型{args.weight} !'
  21. if args.float16:
  22. assert torch.cuda.is_available(), '! cuda不可用,无法使用float16 !'
  23. # -------------------------------------------------------------------------------------------------------------------- #
  24. def export_onnx():
  25. model_dict = torch.load(args.weight, map_location='cpu')
  26. model = model_dict['model']
  27. model = deploy(model, args.normalization)
  28. model = model.eval().half().to(args.device) if args.float16 else model.eval().float().to(args.device)
  29. input_shape = torch.rand(1, args.input_size, args.input_size, 3,
  30. dtype=torch.float16 if args.float16 else torch.float32).to(args.device)
  31. torch.onnx.export(model, input_shape, args.save_name,
  32. opset_version=12, input_names=['input'], output_names=['output'],
  33. dynamic_axes={'input': {args.batch: 'batch_size'}, 'output': {args.batch: 'batch_size'}})
  34. print(f'| 转为onnx模型成功:{args.save_name} |')
  35. if args.sim:
  36. import onnx
  37. import onnxsim
  38. model_onnx = onnx.load(args.save_name)
  39. model_simplify, check = onnxsim.simplify(model_onnx)
  40. onnx.save(model_simplify, args.save_name)
  41. print(f'| 使用onnxsim简化模型成功:{args.save_name} |')
# Script entry point: run the conversion when executed directly.
if __name__ == '__main__':
    export_onnx()
    # Optionally move the generated ONNX file to the configured folder
    # (currently disabled; args.save_path is otherwise unused):
    # destination_folder = args.save_path
    # shutil.move(args.save_name, destination_folder)
    # print(f'| 已将 {args.save_name} 移动到 {destination_folder} 中 |')