浏览代码

新增批量将权重文件导出为onnx脚本

liyan 7 月之前
父节点
当前提交
5eb94a6725
共有 1 个文件被更改,包括 84 次插入0 次删除
  1. 84 0
      export_onnx_batch.py

+ 84 - 0
export_onnx_batch.py

@@ -0,0 +1,84 @@
+import os
+
+import torch
+import torchvision
+
+import onnx
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
+
+    parser.add_argument("--model", default="googlenet", type=str, help="model name")
+    # parser.add_argument("--model_dir", default=None, type=str, help="model checkpoints directory")
+    parser.add_argument("--model_dir", default="checkpoints/googlenet", type=str, help="model checkpoints directory")
+    parser.add_argument(
+        "--num_classes", default=10, type=int, help="number of classes"
+    )
+    parser.add_argument(
+        "--input_size", default=224, type=int, help="input size"
+    )
+    return parser
+
+
+def export_onnx(args, checkpoint_file):
+    # 加载模型
+    model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
+    model.eval()  # 切换到评估模式
+
+    # 加载权重
+    checkpoint = torch.load(checkpoint_file, map_location='cpu')
+    model.load_state_dict(checkpoint["model"])
+
+    # 定义一个随机输入张量,尺寸应符合模型输入要求
+    # 通常,ImageNet预训练模型的输入尺寸是 (batch_size, 3, 224, 224)
+    dummy_input = torch.randn(1, 3, args.input_size, args.input_size)
+
+    save_path = None
+    if checkpoint_file.endswith(".pth"):
+        save_path = checkpoint_file.replace(".pth", ".onnx")
+    if checkpoint_file.endswith(".pt"):
+        save_path = checkpoint_file.replace(".pt", ".onnx")
+
+    if save_path is None:
+        raise ValueError(f"checkpoint file:{checkpoint_file} not end with .pt or .pth")
+
+    # 导出模型到 ONNX 格式
+    with torch.no_grad():
+        torch.onnx.export(
+            model,
+            dummy_input,
+            save_path,
+            export_params=True,
+            opset_version=11,
+            do_constant_folding=False,  # 确保这个参数一定是False,如果不为False,导出的onnx权重与原始权重数值不一致,导致白盒水印提取失败
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
+        )
+
+    print(f"checkpoint_file:{checkpoint_file}, 成功导出到 {save_path}")
+
+    # Checks
+    model_onnx = onnx.load(save_path)  # load onnx model
+    onnx.checker.check_model(model_onnx)  # check onnx model
+
+
+def find_checkpoints_files(root_dir):
+    checkpoint_files = []
+    # 遍历根目录及其子目录
+    for dirpath, _, filenames in os.walk(root_dir):
+        # 查找所有以 .onnx 结尾的文件
+        for filename in filenames:
+            if filename.endswith('.pth') or filename.endswith('.pt'):
+                # 获取完整路径并添加到列表
+                checkpoint_files.append(os.path.join(dirpath, filename))
+    return checkpoint_files
+
+if __name__ == '__main__':
+    args = get_args_parser().parse_args()
+    checkpoint_files = find_checkpoints_files(args.model_dir)
+    for item in checkpoint_files:
+        export_onnx(args, item)