12345678910111213141516171819202122232425262728293031323334353637383940 |
- import os
- import tensorflow as tf
- import tf2onnx
def convert_h5_to_onnx(directory, extensions=(".h5",)):
    """Convert every Keras model file in ``directory`` to an ONNX file beside it.

    Only the top level of ``directory`` is scanned (no recursion). Each regular
    file whose name ends with one of ``extensions`` is loaded with
    ``tf.keras.models.load_model`` and converted with
    ``tf2onnx.convert.from_keras``. The result is written next to the source
    file, with the file's final extension replaced by ``.onnx``
    (``a.h5`` -> ``a.onnx``). An existing file at that path is overwritten.

    Args:
        directory: Path of the directory to scan.
        extensions: Case-sensitive file-name suffix (a str) or tuple of
            suffixes treated as Keras model files. The default ``(".h5",)``
            matches the original behaviour; pass e.g. ``(".h5", ".hdf5")``
            to include more.

    Returns:
        List of the ONNX file paths written, in processing order.

    Raises:
        OSError: if ``directory`` cannot be listed or an output file cannot be
            written. Exceptions raised while loading or converting a model
            propagate and stop the remaining conversions.
    """
    # A bare string would otherwise be split into single-character suffixes.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    written = []
    # Sort so conversions (and their log lines) happen in a deterministic order.
    for file_name in sorted(os.listdir(directory)):
        h5_path = os.path.join(directory, file_name)
        # Skip non-matching names and non-regular entries (e.g. a folder named "x.h5").
        if not file_name.endswith(suffixes) or not os.path.isfile(h5_path):
            continue
        # Replace only the final extension; str.replace would also rewrite an
        # ".h5" appearing elsewhere in the name.
        stem = os.path.splitext(file_name)[0]
        onnx_path = os.path.join(directory, stem + ".onnx")

        # Load the Keras model. Compilation (optimizer/loss/metrics) is not
        # needed for export, and skipping it avoids failures when the training
        # loss or metrics were custom objects.
        model = tf.keras.models.load_model(h5_path, compile=False)

        # One TensorSpec per model input, with the batch dimension left dynamic.
        # Building from model.inputs covers multi-input models and keeps each
        # input's own dtype instead of forcing float32.
        spec = tuple(
            tf.TensorSpec((None, *tuple(inp.shape)[1:]), inp.dtype)
            for inp in model.inputs
        )
        onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)

        # Save the serialized ONNX protobuf.
        with open(onnx_path, "wb") as f:
            f.write(onnx_model.SerializeToString())
        print(f"Converted {h5_path} to {onnx_path}")
        written.append(onnx_path)

        # Reset Keras global state so memory does not grow across many models.
        tf.keras.backend.clear_session()
    return written
def get_args_parser(add_help=True):
    """Create the command-line parser for the H5-to-ONNX converter.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether
            the ``-h/--help`` option is registered.

    Returns:
        An ``argparse.ArgumentParser`` with a single optional ``--model_dir``
        string option that defaults to ``None``.
    """
    import argparse as _argparse

    description = "model weight transport to onnx"
    cli = _argparse.ArgumentParser(add_help=add_help, description=description)
    cli.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="model checkpoints directory",
    )
    return cli
if __name__ == "__main__":
    # The directory holding the .h5 files comes from --model_dir.
    parser = get_args_parser()
    args = parser.parse_args()
    # Report CLI misuse through argparse (usage line + exit status 2) instead
    # of an uncaught exception traceback.
    if args.model_dir is None:
        parser.error("--model_dir must be specified")
    # Fail fast with a clear message rather than an os.listdir traceback.
    if not os.path.isdir(args.model_dir):
        parser.error(f"--model_dir is not a directory: {args.model_dir}")
    convert_h5_to_onnx(args.model_dir)
|