ソースを参照

新增mindspore框架训练流程导出为onnx文件

liyan 11 ヶ月 前
コミット
c901071fbc
1 ファイル変更4 行追加0 行削除
  1. 4 0
      tests/train.py

+ 4 - 0
tests/train.py

@@ -101,6 +101,9 @@ def test(model, dataset, loss_fn):
 def save(model, save_path):
     mindspore.save_checkpoint(model, save_path)
 
+def export_onnx(model):
+    mindspore.export(model, train_dataset, file_name='./run/train/AlexNet.onnx', file_format='ONNX')
+
 
 if __name__ == '__main__':
     epochs = 50
@@ -110,4 +113,5 @@ if __name__ == '__main__':
         train(model, train_dataset)
         test(model, test_dataset, loss_fn)
         save(model, save_path)
+        export_onnx(model)
     print("Done!")