# export_onnx.py (1.3 KB)
import os
from pathlib import Path

import tensorflow as tf
import tf2onnx
  4. def convert_h5_to_onnx(directory):
  5. # 遍历目录中的所有文件
  6. for file_name in os.listdir(directory):
  7. if file_name.endswith(".h5"):
  8. h5_path = os.path.join(directory, file_name)
  9. onnx_path = os.path.join(directory, file_name.replace(".h5", ".onnx"))
  10. # 加载 h5 模型
  11. model = tf.keras.models.load_model(h5_path)
  12. # 使用 tf2onnx 进行转换
  13. spec = (tf.TensorSpec((None, *model.input_shape[1:]), tf.float32),)
  14. onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)
  15. # 保存为 .onnx 文件
  16. with open(onnx_path, "wb") as f:
  17. f.write(onnx_model.SerializeToString())
  18. print(f"Converted {h5_path} to {onnx_path}")
  19. def get_args_parser(add_help=True):
  20. import argparse
  21. parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
  22. parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
  23. return parser
  24. if __name__ == "__main__":
  25. # 指定包含 .h5 文件的目录路径
  26. args = get_args_parser().parse_args()
  27. if args.model_dir is None:
  28. raise ValueError("--model_dir must be specified")
  29. convert_h5_to_onnx(args.model_dir)