# export_onnx.py (495 B)
  1. import tensorflow as tf
  2. import tf2onnx
  3. # 1. 加载保存的 TensorFlow 模型
  4. model = tf.keras.models.load_model('saved_model')
  5. # 2. 定义模型的输入签名
  6. spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
  7. # 3. 转换并保存为 ONNX 格式
  8. output_path = 'vgg16_tensorflow.onnx'
  9. model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=11)
  10. # 4. 保存 ONNX 模型
  11. with open(output_path, "wb") as f:
  12. f.write(model_proto.SerializeToString())