1234567891011121314151617 |
# Convert a saved Keras model to ONNX and write the result to disk.
import tensorflow as tf
import tf2onnx

# 1. Load the saved TensorFlow model from the 'saved_model' path
#    (resolved relative to the current working directory).
model = tf.keras.models.load_model('saved_model')

# 2. Define the model's input signature: one float32 tensor named "input"
#    with an unconstrained first dimension and fixed trailing dims 32x32x3.
#    NOTE(review): presumably (batch, height, width, channels) -- confirm
#    against the model that was saved.
spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)

# 3. Convert the Keras model to an ONNX model proto using opset 11.
#    The second return value is not needed here and is discarded.
output_path = 'vgg16_tensorflow.onnx'
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=11)

# 4. Save the ONNX model: serialize the proto and write the raw bytes.
with open(output_path, "wb") as f:
    f.write(model_proto.SerializeToString())
|