export_onnx.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import torch
  2. import torchvision
  3. import onnx
  4. def get_args_parser(add_help=True):
  5. import argparse
  6. parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
  7. parser.add_argument("--model", default="resnet18", type=str, help="model name")
  8. parser.add_argument(
  9. "--num_classes", default=10, type=int, help="number of classes"
  10. )
  11. parser.add_argument(
  12. "--input_size", default=224, type=int, help="input size"
  13. )
  14. parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
  15. parser.add_argument("--save_path", default=None, type=str, help="onnx file save path")
  16. return parser
  17. def export_onnx(args):
  18. # 加载模型
  19. model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
  20. model.eval() # 切换到评估模式
  21. # 加载权重
  22. checkpoint = torch.load(args.weights, map_location='cpu')
  23. model.load_state_dict(checkpoint["model"])
  24. # 定义一个随机输入张量,尺寸应符合模型输入要求
  25. # 通常,ImageNet预训练模型的输入尺寸是 (batch_size, 3, 224, 224)
  26. dummy_input = torch.randn(1, 3, args.input_size, args.input_size)
  27. # 导出模型到 ONNX 格式
  28. with torch.no_grad():
  29. torch.onnx.export(
  30. model,
  31. dummy_input,
  32. args.save_path,
  33. export_params=True,
  34. opset_version=11,
  35. do_constant_folding=False, # 确保这个参数一定是False,如果不为False,导出的onnx权重与原始权重数值不一致,导致白盒水印提取失败
  36. input_names=["input"],
  37. output_names=["output"],
  38. dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
  39. )
  40. print("模型成功导出为 ONNX 格式!")
  41. # Checks
  42. model_onnx = onnx.load(args.save_path) # load onnx model
  43. onnx.checker.check_model(model_onnx) # check onnx model
  44. if __name__ == '__main__':
  45. args = get_args_parser().parse_args()
  46. export_onnx(args)