|
@@ -101,6 +101,9 @@ def test(model, dataset, loss_fn):
|
|
def save(model, save_path):
|
|
def save(model, save_path):
|
|
mindspore.save_checkpoint(model, save_path)
|
|
mindspore.save_checkpoint(model, save_path)
|
|
|
|
|
|
|
|
+def export_onnx(model):
|
|
|
|
+ mindspore.export(model, train_dataset, file_name='./run/train/AlexNet.onnx', file_format='ONNX')
|
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
epochs = 50
|
|
epochs = 50
|
|
@@ -110,4 +113,5 @@ if __name__ == '__main__':
|
|
train(model, train_dataset)
|
|
train(model, train_dataset)
|
|
test(model, test_dataset, loss_fn)
|
|
test(model, test_dataset, loss_fn)
|
|
save(model, save_path)
|
|
save(model, save_path)
|
|
|
|
+ export_onnx(model)
|
|
print("Done!")
|
|
print("Done!")
|