Selaa lähdekoodia

修改导出onnx脚本

liyan 7 kuukautta sitten
vanhempi
commit
1c26fb1de4
1 muutettua tiedostoa jossa 33 lisäystä ja 10 poistoa
  1. 33 10
      export_onnx.py

+ 33 - 10
export_onnx.py

@@ -1,17 +1,40 @@
+import os
 import tensorflow as tf
 import tf2onnx
 
 
-# 1. 加载保存的 TensorFlow 模型
-model = tf.keras.models.load_model('saved_model')
+def convert_h5_to_onnx(directory):
+    # 遍历目录中的所有文件
+    for file_name in os.listdir(directory):
+        if file_name.endswith(".h5"):
+            h5_path = os.path.join(directory, file_name)
+            onnx_path = os.path.join(directory, file_name.replace(".h5", ".onnx"))
 
-# 2. 定义模型的输入签名
-spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
+            # 加载 h5 模型
+            model = tf.keras.models.load_model(h5_path)
 
-# 3. 转换并保存为 ONNX 格式
-output_path = 'vgg16_tensorflow.onnx'
-model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=11)
+            # 使用 tf2onnx 进行转换
+            spec = (tf.TensorSpec((None, *model.input_shape[1:]), tf.float32),)
+            onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)
 
-# 4. 保存 ONNX 模型
-with open(output_path, "wb") as f:
-    f.write(model_proto.SerializeToString())
+            # 保存为 .onnx 文件
+            with open(onnx_path, "wb") as f:
+                f.write(onnx_model.SerializeToString())
+
+            print(f"Converted {h5_path} to {onnx_path}")
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
+
+    parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
+    return parser
+
+if __name__ == "__main__":
+    # 指定包含 .h5 文件的目录路径
+    args = get_args_parser().parse_args()
+
+    if args.model_dir is None:
+        raise ValueError("--model_dir must be specified")
+    convert_h5_to_onnx(args.model_dir)