|
@@ -1,17 +1,40 @@
|
|
|
+import os
|
|
|
import tensorflow as tf
|
|
|
import tf2onnx
|
|
|
|
|
|
|
|
|
-# 1. 加载保存的 TensorFlow 模型
|
|
|
-model = tf.keras.models.load_model('saved_model')
|
|
|
+def convert_h5_to_onnx(directory):
|
|
|
+ # 遍历目录中的所有文件
|
|
|
+ for file_name in os.listdir(directory):
|
|
|
+ if file_name.endswith(".h5"):
|
|
|
+ h5_path = os.path.join(directory, file_name)
|
|
|
+ onnx_path = os.path.join(directory, file_name.replace(".h5", ".onnx"))
|
|
|
|
|
|
-# 2. 定义模型的输入签名
|
|
|
-spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
|
|
|
+ # 加载 h5 模型
|
|
|
+ model = tf.keras.models.load_model(h5_path)
|
|
|
|
|
|
-# 3. 转换并保存为 ONNX 格式
|
|
|
-output_path = 'vgg16_tensorflow.onnx'
|
|
|
-model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=11)
|
|
|
+ # 使用 tf2onnx 进行转换
|
|
|
+ spec = (tf.TensorSpec((None, *model.input_shape[1:]), tf.float32),)
|
|
|
+ onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)
|
|
|
|
|
|
-# 4. 保存 ONNX 模型
|
|
|
-with open(output_path, "wb") as f:
|
|
|
- f.write(model_proto.SerializeToString())
|
|
|
+ # 保存为 .onnx 文件
|
|
|
+ with open(onnx_path, "wb") as f:
|
|
|
+ f.write(onnx_model.SerializeToString())
|
|
|
+
|
|
|
+ print(f"Converted {h5_path} to {onnx_path}")
|
|
|
+
|
|
|
+def get_args_parser(add_help=True):
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
|
|
|
+
|
|
|
+ parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
|
|
|
+ return parser
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 指定包含 .h5 文件的目录路径
|
|
|
+ args = get_args_parser().parse_args()
|
|
|
+
|
|
|
+ if args.model_dir is None:
|
|
|
+ raise ValueError("--model_dir must be specified")
|
|
|
+ convert_h5_to_onnx(args.model_dir)
|