|
@@ -0,0 +1,84 @@
|
|
|
+import os
|
|
|
+
|
|
|
+import torch
|
|
|
+import torchvision
|
|
|
+
|
|
|
+import onnx
|
|
|
+
|
|
|
+
|
|
|
+def get_args_parser(add_help=True):
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
|
|
|
+
|
|
|
+ parser.add_argument("--model", default="googlenet", type=str, help="model name")
|
|
|
+
|
|
|
+ parser.add_argument("--model_dir", default="checkpoints/googlenet", type=str, help="model checkpoints directory")
|
|
|
+ parser.add_argument(
|
|
|
+ "--num_classes", default=10, type=int, help="number of classes"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--input_size", default=224, type=int, help="input size"
|
|
|
+ )
|
|
|
+ return parser
|
|
|
+
|
|
|
+
|
|
|
+def export_onnx(args, checkpoint_file):
|
|
|
+
|
|
|
+ model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
|
|
|
+ model.eval()
|
|
|
+
|
|
|
+
|
|
|
+ checkpoint = torch.load(checkpoint_file, map_location='cpu')
|
|
|
+ model.load_state_dict(checkpoint["model"])
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ dummy_input = torch.randn(1, 3, args.input_size, args.input_size)
|
|
|
+
|
|
|
+ save_path = None
|
|
|
+ if checkpoint_file.endswith(".pth"):
|
|
|
+ save_path = checkpoint_file.replace(".pth", ".onnx")
|
|
|
+ if checkpoint_file.endswith(".pt"):
|
|
|
+ save_path = checkpoint_file.replace(".pt", ".onnx")
|
|
|
+
|
|
|
+ if save_path is None:
|
|
|
+ raise ValueError(f"checkpoint file:{checkpoint_file} not end with .pt or .pth")
|
|
|
+
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ torch.onnx.export(
|
|
|
+ model,
|
|
|
+ dummy_input,
|
|
|
+ save_path,
|
|
|
+ export_params=True,
|
|
|
+ opset_version=11,
|
|
|
+ do_constant_folding=False,
|
|
|
+ input_names=["input"],
|
|
|
+ output_names=["output"],
|
|
|
+ dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"checkpoint_file:{checkpoint_file}, 成功导出到 {save_path}")
|
|
|
+
|
|
|
+
|
|
|
+ model_onnx = onnx.load(save_path)
|
|
|
+ onnx.checker.check_model(model_onnx)
|
|
|
+
|
|
|
+
|
|
|
+def find_checkpoints_files(root_dir):
|
|
|
+ checkpoint_files = []
|
|
|
+
|
|
|
+ for dirpath, _, filenames in os.walk(root_dir):
|
|
|
+
|
|
|
+ for filename in filenames:
|
|
|
+ if filename.endswith('.pth') or filename.endswith('.pt'):
|
|
|
+
|
|
|
+ checkpoint_files.append(os.path.join(dirpath, filename))
|
|
|
+ return checkpoint_files
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ args = get_args_parser().parse_args()
|
|
|
+ checkpoint_files = find_checkpoints_files(args.model_dir)
|
|
|
+ for item in checkpoint_files:
|
|
|
+ export_onnx(args, item)
|