Explorar o código

新增mindspore框架训练流程导出为onnx文件

liyan hai 11 meses
pai
achega
c901071fbc
Modificáronse 1 ficheiros con 4 adicións e 0 borrados
  1. 4 0
      tests/train.py

+ 4 - 0
tests/train.py

@@ -101,6 +101,9 @@ def test(model, dataset, loss_fn):
 def save(model, save_path):
     mindspore.save_checkpoint(model, save_path)
 
+def export_onnx(model):
+    mindspore.export(model, train_dataset, file_name='./run/train/AlexNet.onnx', file_format='ONNX')
+
 
 if __name__ == '__main__':
     epochs = 50
@@ -110,4 +113,5 @@ if __name__ == '__main__':
         train(model, train_dataset)
         test(model, test_dataset, loss_fn)
         save(model, save_path)
+        export_onnx(model)
     print("Done!")