"""Export torchvision model checkpoints (.pt/.pth) found under a directory to ONNX."""
import argparse
import os

import onnx
import torch
import torchvision

# Checkpoint file extensions this script picks up and converts.
CHECKPOINT_EXTENSIONS = (".pth", ".pt")


def get_args_parser(add_help=True):
    """Build the command-line parser for the checkpoint-to-ONNX exporter.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass ``False`` when
            this parser is used as a parent of another parser.

    Returns:
        An ``argparse.ArgumentParser`` with ``--model``, ``--model_dir``,
        ``--num_classes``, ``--input_size`` and ``--opset_version`` options.
    """
    parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
    parser.add_argument("--model", default="googlenet", type=str, help="model name")
    parser.add_argument("--model_dir", default="checkpoints/googlenet", type=str, help="model checkpoints directory")
    parser.add_argument("--num_classes", default=10, type=int, help="number of classes")
    parser.add_argument("--input_size", default=224, type=int, help="input size")
    parser.add_argument("--opset_version", default=11, type=int, help="ONNX opset version used for export")
    return parser


def _onnx_save_path(checkpoint_file):
    """Return the ``.onnx`` path next to *checkpoint_file* (only the suffix is swapped).

    Directory names or stems that merely contain ``.pt``/``.pth`` are left
    untouched, unlike a plain ``str.replace`` on the whole path.

    Raises:
        ValueError: if *checkpoint_file* does not end with ``.pt`` or ``.pth``.
    """
    for ext in CHECKPOINT_EXTENSIONS:
        if checkpoint_file.endswith(ext):
            return checkpoint_file[: -len(ext)] + ".onnx"
    raise ValueError(f"checkpoint file:{checkpoint_file} not end with .pt or .pth")


def _extract_state_dict(checkpoint):
    """Return the model state dict held by a loaded checkpoint object.

    Checkpoints saved as ``{"model": state_dict, ...}`` yield their ``"model"``
    entry; any other object is assumed to already be a bare state dict.
    """
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        return checkpoint["model"]
    return checkpoint


def export_onnx(args, checkpoint_file):
    """Export one checkpoint to an ``.onnx`` file beside it, then validate the result.

    Args:
        args: Parsed CLI namespace; reads ``model``, ``num_classes``,
            ``input_size`` and, if present, ``opset_version`` (default 11).
        checkpoint_file: Path to a ``.pt``/``.pth`` checkpoint.

    Raises:
        ValueError: if *checkpoint_file* does not end with ``.pt`` or ``.pth``.
    """
    # Resolve the output path first so a bad filename fails before the
    # comparatively expensive model construction and checkpoint load.
    save_path = _onnx_save_path(checkpoint_file)

    # Build the architecture without pretrained weights and switch to eval mode.
    model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)
    model.eval()

    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects checkpoints that pickle arbitrary Python
    # objects alongside the weights. If loading fails for that reason, pass
    # weights_only=False -- only for trusted checkpoints, as it unpickles code.
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    model.load_state_dict(_extract_state_dict(checkpoint))

    # Random example input of shape (1, 3, input_size, input_size); the batch
    # axis is declared dynamic below, so batch size 1 does not fix the model.
    dummy_input = torch.randn(1, 3, args.input_size, args.input_size)

    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=getattr(args, "opset_version", 11),
            # Must stay False: with constant folding enabled the exported ONNX
            # weights differ numerically from the original weights, which makes
            # white-box watermark extraction fail.
            do_constant_folding=False,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
    print(f"checkpoint_file:{checkpoint_file}, 成功导出到 {save_path}")

    # Re-load the written file and run the ONNX structural checker on it.
    onnx.checker.check_model(onnx.load(save_path))


def find_checkpoints_files(root_dir):
    """Recursively collect ``.pth``/``.pt`` checkpoint paths under *root_dir*.

    Returns:
        A sorted list of full file paths, so repeated runs export in a stable
        order. Empty when nothing matches (``os.walk`` also yields nothing for
        a directory that does not exist).
    """
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith(CHECKPOINT_EXTENSIONS)
    )


def main():
    """Export every checkpoint found under ``--model_dir`` to ONNX."""
    parser = get_args_parser()
    args = parser.parse_args()
    # os.walk silently yields nothing for a missing directory; report it instead.
    if not os.path.isdir(args.model_dir):
        parser.error(f"--model_dir is not a directory: {args.model_dir}")
    checkpoint_files = find_checkpoints_files(args.model_dir)
    if not checkpoint_files:
        print(f"no .pt/.pth checkpoints found under {args.model_dir}")
        return
    for item in checkpoint_files:
        export_onnx(args, item)


if __name__ == "__main__":
    main()