"""Export a trained torchvision classification checkpoint to ONNX."""
import torch

# NOTE: torchvision and onnx are imported inside export_onnx() so that plain
# argument parsing (e.g. ``--help``) does not pay the cost of loading them.


def get_args_parser(add_help=True):
    """Build the command-line parser for the ONNX export script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether
            ``-h/--help`` is registered.

    Returns:
        argparse.ArgumentParser: parser with ``--model``, ``--num_classes``,
        ``--input_size``, ``--weights`` and ``--save_path`` options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="model weight transport to onnx", add_help=add_help)
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--num_classes", default=10, type=int, help="number of classes")
    parser.add_argument("--input_size", default=224, type=int, help="input size")
    parser.add_argument(
        "--weights", default=None, type=str, help="path of the checkpoint file to load (required)"
    )
    parser.add_argument(
        "--save_path", default=None, type=str, help="onnx file save path (default: <model>.onnx)"
    )
    return parser


def export_onnx(args):
    """Load a checkpoint into a torchvision model and export it to ONNX.

    Args:
        args: Namespace as produced by ``get_args_parser().parse_args()``.
            ``args.weights`` must be the path of a checkpoint file saved with
            ``torch.save``; it may be either a dict holding the state_dict
            under ``"model"`` or a bare state_dict. ``args.save_path`` falls
            back to ``"<model>.onnx"`` when not given.

    Returns:
        str: the path the ONNX file was written to.

    Raises:
        ValueError: if ``args.weights`` is not provided.
    """
    if not args.weights:
        raise ValueError("--weights is required: path of the checkpoint to export")

    import onnx
    import torchvision

    save_path = args.save_path or f"{args.model}.onnx"

    # Build the bare architecture; the trained parameters come from the checkpoint.
    model = torchvision.models.get_model(args.model, weights=None, num_classes=args.num_classes)

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code -- only export checkpoints from trusted sources.
    checkpoint = torch.load(args.weights, map_location="cpu")
    # Accept both training checkpoints ({"model": state_dict, ...}) and bare state_dicts.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()  # switch to evaluation mode before exporting

    # Random input of shape (1, 3, input_size, input_size) used to drive the export.
    dummy_input = torch.randn(1, 3, args.input_size, args.input_size)

    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            save_path,
            export_params=True,
            opset_version=11,
            # Must stay False: otherwise the weights in the exported ONNX file
            # no longer match the original weights numerically, and white-box
            # watermark extraction fails.
            do_constant_folding=False,
            input_names=["input"],
            output_names=["output"],
            # Batch dimension is left dynamic in both input and output.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    # Validate the written file before reporting success.
    model_onnx = onnx.load(save_path)
    onnx.checker.check_model(model_onnx)
    print("模型成功导出为 ONNX 格式!")
    return save_path


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    export_onnx(args)