import os
import tensorflow as tf
import tf2onnx


def convert_h5_to_onnx(directory):
    """Convert every Keras ``.h5`` model file in *directory* to ONNX.

    Each regular file whose name ends in ``.h5`` is loaded with
    ``tf.keras.models.load_model``, converted with
    ``tf2onnx.convert.from_keras`` and written next to the source file under
    the same name with the trailing ``.h5`` replaced by ``.onnx``. Only the
    entries directly inside *directory* are considered.

    Args:
        directory: Path of the directory to scan for ``.h5`` files.
    """
    h5_suffix = ".h5"

    # os.listdir returns entries in arbitrary order; sort so the conversion
    # order (and the printed log) is reproducible.
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(h5_suffix):
            continue

        h5_path = os.path.join(directory, file_name)
        # A directory may also be named "*.h5"; only regular files are
        # treated as model files.
        if not os.path.isfile(h5_path):
            continue

        # Swap only the trailing extension. str.replace would also rewrite
        # ".h5" occurring earlier in the name (e.g. "net.h5.best.h5").
        onnx_name = file_name[: -len(h5_suffix)] + ".onnx"
        onnx_path = os.path.join(directory, onnx_name)

        # Load the h5 model.
        model = tf.keras.models.load_model(h5_path)

        # Convert with tf2onnx. The signature declares one float32 input with
        # a dynamic batch axis followed by the model's non-batch dimensions.
        # NOTE(review): presumably assumes a single-input float32 model --
        # confirm before using this on multi-input or non-float models.
        spec = (tf.TensorSpec((None, *model.input_shape[1:]), tf.float32),)
        onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)

        # Save as an .onnx file.
        with open(onnx_path, "wb") as f:
            f.write(onnx_model.SerializeToString())

        print(f"Converted {h5_path} to {onnx_path}")

def get_args_parser(add_help=True):
    """Build the command-line parser for the conversion script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser`` to control whether
            the automatic ``-h/--help`` option is registered.

    Returns:
        An ``argparse.ArgumentParser`` that accepts a single optional
        ``--model_dir`` string argument defaulting to ``None``.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(
        description="model weight transport to onnx",
        add_help=add_help,
    )
    cli.add_argument(
        "--model_dir",
        type=str,
        default=None,
        help="model checkpoints directory",
    )
    return cli

if __name__ == "__main__":
    # The directory containing the .h5 files is taken from the command line.
    args = get_args_parser().parse_args()
    model_dir = args.model_dir

    # --model_dir defaults to None, so fail before attempting any conversion
    # when it was not supplied.
    if model_dir is None:
        raise ValueError("--model_dir must be specified")

    convert_h5_to_onnx(model_dir)