123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- import os
- import torch
- import torchvision
- import onnx
def get_args_parser(add_help=True):
    """Build the command-line parser for the checkpoint-to-ONNX exporter.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; controls whether
            the ``-h/--help`` option is registered.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--model``, ``--model_dir``,
        ``--num_classes`` and ``--input_size``.
    """
    import argparse

    # (flag, default, type, help) — registered in this order.
    option_table = (
        ("--model", "googlenet", str, "model name"),
        ("--model_dir", "checkpoints/googlenet", str, "model checkpoints directory"),
        ("--num_classes", 10, int, "number of classes"),
        ("--input_size", 224, int, "input size"),
    )

    cli = argparse.ArgumentParser(
        description="model weight transport to onnx", add_help=add_help
    )
    for flag, default_value, value_type, help_text in option_table:
        cli.add_argument(flag, default=default_value, type=value_type, help=help_text)
    return cli
def _onnx_path_for(checkpoint_file):
    """Return the ``.onnx`` path for *checkpoint_file*, swapping only its trailing extension.

    Raises:
        ValueError: if *checkpoint_file* does not end with ``.pt`` or ``.pth``.
    """
    # Only the suffix is replaced; a plain str.replace would also rewrite
    # ".pt"/".pth" occurring inside directory names (e.g. "run.pt_v2/").
    for ext in (".pth", ".pt"):
        if checkpoint_file.endswith(ext):
            return checkpoint_file[: -len(ext)] + ".onnx"
    raise ValueError(f"checkpoint file:{checkpoint_file} not end with .pt or .pth")


def export_onnx(args, checkpoint_file):
    """Export one checkpoint to an ONNX file written next to it, then validate it.

    Builds ``args.model`` via ``torchvision.models.get_model`` with no
    pretrained weights and ``args.num_classes`` outputs, loads the checkpoint
    weights, exports with a ``(1, 3, input_size, input_size)`` random input
    (batch axis marked dynamic), and runs ``onnx.checker.check_model`` on the
    written file.

    Args:
        args: parsed CLI namespace; reads ``model``, ``num_classes`` and
            ``input_size``.
        checkpoint_file: path ending in ``.pt`` or ``.pth``. The payload may be
            a dict holding the weights under ``"model"``, or a bare state_dict.

    Returns:
        The path of the written ``.onnx`` file.

    Raises:
        ValueError: if *checkpoint_file* does not end with ``.pt`` or ``.pth``.
    """
    # Resolve (and validate) the output path first so a bad file name fails
    # before any model is built or checkpoint is read.
    save_path = _onnx_path_for(checkpoint_file)

    model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
    model.eval()

    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects checkpoints that pickle non-tensor
    # objects; only relax that for checkpoint files you trust.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        # NOTE(review): assumes any other payload is a bare state_dict — confirm.
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    # Batch size 1 is only for tracing; dynamic_axes below frees the batch dim.
    dummy_input = torch.randn(1, 3, args.input_size, args.input_size)

    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=False,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    # Validate the exported graph before reporting success.
    model_onnx = onnx.load(save_path)
    onnx.checker.check_model(model_onnx)
    print(f"checkpoint_file:{checkpoint_file}, 成功导出到 {save_path}")
    return save_path
def find_checkpoints_files(root_dir):
    """Recursively collect PyTorch checkpoint files under *root_dir*.

    Args:
        root_dir: directory to search with ``os.walk``.

    Returns:
        A sorted list of paths (joined onto *root_dir*) whose file name ends
        with ``.pt`` or ``.pth``. Sorting makes the export order reproducible,
        since ``os.walk`` yields entries in filesystem order. A nonexistent
        *root_dir* yields an empty list, because ``os.walk`` ignores errors by
        default.
    """
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith((".pth", ".pt"))
    )
if __name__ == '__main__':
    # CLI entry point: export every checkpoint found under --model_dir.
    args = get_args_parser().parse_args()
    checkpoint_files = find_checkpoints_files(args.model_dir)
    # Without this check a mistyped --model_dir would exit successfully having
    # exported nothing; report it and exit non-zero instead.
    if not checkpoint_files:
        raise SystemExit(f"no .pt/.pth checkpoint files found under {args.model_dir}")
    for item in checkpoint_files:
        export_onnx(args, item)
|