"""Convert every Keras ``.h5`` model found in a directory to ONNX format.

Each ``<name>.h5`` file is written out as ``<name>.onnx`` in the same
directory.
"""

import argparse
import os

import tensorflow as tf
import tf2onnx


def _onnx_path_for(h5_path):
    """Return the ``.onnx`` output path that sits next to *h5_path*.

    Only the trailing extension is swapped, so ``net.h5.v2.h5`` maps to
    ``net.h5.v2.onnx`` (a plain ``str.replace`` would rewrite every ``.h5``).
    """
    root, _ = os.path.splitext(h5_path)
    return root + ".onnx"


def _input_signature(model, dtype=tf.float32):
    """Build a tf2onnx ``input_signature`` with a dynamic batch dimension.

    ``model.input_shape`` is a single shape tuple for single-input models and
    a sequence of shape tuples for multi-input models; both are handled.
    The first (batch) dimension of every input is replaced with ``None``.
    """
    shapes = model.input_shape
    # A single shape's first element is the batch dim (None or int), never a
    # tuple/list, so this distinguishes "one shape" from "several shapes".
    if shapes and isinstance(shapes[0], (list, tuple)):
        shape_list = list(shapes)
    else:
        shape_list = [shapes]
    return tuple(tf.TensorSpec((None, *shape[1:]), dtype) for shape in shape_list)


def convert_h5_to_onnx(directory):
    """Convert every ``.h5`` Keras model in *directory* to an ``.onnx`` file.

    Files are processed in sorted order; the ``.h5`` extension is matched
    case-insensitively and only regular files are considered. Inputs are
    exported as float32 with a dynamic batch dimension.

    Args:
        directory: Path of the directory to scan (not recursive).

    Raises:
        OSError: If *directory* cannot be listed, or an output file cannot
            be written.
    """
    for file_name in sorted(os.listdir(directory)):
        h5_path = os.path.join(directory, file_name)
        # Skip non-.h5 entries and directories that happen to end in ".h5".
        if not file_name.lower().endswith(".h5") or not os.path.isfile(h5_path):
            continue
        onnx_path = _onnx_path_for(h5_path)

        # Load the h5 model. compile=False: training config (loss/metrics/
        # optimizer) is irrelevant for export and would otherwise fail for
        # models saved with custom loss or metric objects.
        model = tf.keras.models.load_model(h5_path, compile=False)

        # Convert with tf2onnx.
        spec = _input_signature(model)
        onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)

        # Save as an .onnx file.
        with open(onnx_path, "wb") as f:
            f.write(onnx_model.SerializeToString())
        print(f"Converted {h5_path} to {onnx_path}")


def get_args_parser(add_help=True):
    """Return the command-line parser for this script.

    Args:
        add_help: Passed through to ``argparse.ArgumentParser``; set False
            when embedding this parser as a parent of another parser.
    """
    parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
    parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
    return parser


if __name__ == "__main__":
    # Directory containing the .h5 files to convert.
    args = get_args_parser().parse_args()
    if args.model_dir is None:
        raise ValueError("--model_dir must be specified")
    convert_h5_to_onnx(args.model_dir)