export_onnx_batch.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import torch
  3. import torchvision
  4. import onnx
  5. def get_args_parser(add_help=True):
  6. import argparse
  7. parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
  8. parser.add_argument("--model", default="googlenet", type=str, help="model name")
  9. # parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
  10. parser.add_argument("--model_dir", default="checkpoints/googlenet", type=str, help="model checkpoints directory")
  11. parser.add_argument(
  12. "--num_classes", default=10, type=int, help="number of classes"
  13. )
  14. parser.add_argument(
  15. "--input_size", default=224, type=int, help="input size"
  16. )
  17. return parser
  18. def export_onnx(args, checkpoint_file):
  19. # 加载模型
  20. model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
  21. model.eval() # 切换到评估模式
  22. # 加载权重
  23. checkpoint = torch.load(checkpoint_file, map_location='cpu')
  24. model.load_state_dict(checkpoint["model"])
  25. # 定义一个随机输入张量,尺寸应符合模型输入要求
  26. # 通常,ImageNet预训练模型的输入尺寸是 (batch_size, 3, 224, 224)
  27. dummy_input = torch.randn(1, 3, args.input_size, args.input_size)
  28. save_path = None
  29. if checkpoint_file.endswith(".pth"):
  30. save_path = checkpoint_file.replace(".pth", ".onnx")
  31. if checkpoint_file.endswith(".pt"):
  32. save_path = checkpoint_file.replace(".pt", ".onnx")
  33. if save_path is None:
  34. raise ValueError(f"checkpoint file:{checkpoint_file} not end with .pt or .pth")
  35. # 导出模型到 ONNX 格式
  36. with torch.no_grad():
  37. torch.onnx.export(
  38. model,
  39. dummy_input,
  40. save_path,
  41. export_params=True,
  42. opset_version=11,
  43. do_constant_folding=False, # 确保这个参数一定是False,如果不为False,导出的onnx权重与原始权重数值不一致,导致白盒水印提取失败
  44. input_names=["input"],
  45. output_names=["output"],
  46. dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
  47. )
  48. print(f"checkpoint_file:{checkpoint_file}, 成功导出到 {save_path}")
  49. # Checks
  50. model_onnx = onnx.load(save_path) # load onnx model
  51. onnx.checker.check_model(model_onnx) # check onnx model
  52. def find_checkpoints_files(root_dir):
  53. checkpoint_files = []
  54. # 遍历根目录及其子目录
  55. for dirpath, _, filenames in os.walk(root_dir):
  56. # 查找所有以 .onnx 结尾的文件
  57. for filename in filenames:
  58. if filename.endswith('.pth') or filename.endswith('.pt'):
  59. # 获取完整路径并添加到列表
  60. checkpoint_files.append(os.path.join(dirpath, filename))
  61. return checkpoint_files
  62. if __name__ == '__main__':
  63. args = get_args_parser().parse_args()
  64. checkpoint_files = find_checkpoints_files(args.model_dir)
  65. for item in checkpoint_files:
  66. export_onnx(args, item)