|
@@ -1,17 +1,40 @@
|
|
|
+import os
|
|
|
import tensorflow as tf
|
|
|
import tf2onnx
|
|
|
|
|
|
|
|
|
-
|
|
|
-model = tf.keras.models.load_model('saved_model')
|
|
|
+def convert_h5_to_onnx(directory):
|
|
|
+
|
|
|
+ for file_name in os.listdir(directory):
|
|
|
+ if file_name.endswith(".h5"):
|
|
|
+ h5_path = os.path.join(directory, file_name)
|
|
|
+ onnx_path = os.path.join(directory, file_name.replace(".h5", ".onnx"))
|
|
|
|
|
|
-
|
|
|
-spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
|
|
|
+
|
|
|
+ model = tf.keras.models.load_model(h5_path)
|
|
|
|
|
|
-
|
|
|
-output_path = 'vgg16_tensorflow.onnx'
|
|
|
-model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=11)
|
|
|
+
|
|
|
+ spec = (tf.TensorSpec((None, *model.input_shape[1:]), tf.float32),)
|
|
|
+ onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec)
|
|
|
|
|
|
-
|
|
|
-with open(output_path, "wb") as f:
|
|
|
- f.write(model_proto.SerializeToString())
|
|
|
+
|
|
|
+ with open(onnx_path, "wb") as f:
|
|
|
+ f.write(onnx_model.SerializeToString())
|
|
|
+
|
|
|
+ print(f"Converted {h5_path} to {onnx_path}")
|
|
|
+
|
|
|
+def get_args_parser(add_help=True):
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
|
|
|
+
|
|
|
+ parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
|
|
|
+ return parser
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+
|
|
|
+ args = get_args_parser().parse_args()
|
|
|
+
|
|
|
+ if args.model_dir is None:
|
|
|
+ raise ValueError("--model_dir must be specified")
|
|
|
+ convert_h5_to_onnx(args.model_dir)
|