"""Check the key produced by ``embeder.onnx`` and evaluate ``host_dnn.onnx``.

Steps:
1. Run ``embeder.onnx`` on the flattened first initializer of
   ``host_dnn.onnx``.
2. Compare that output against the key stored in ``b.npy`` and print the
   outcome.
3. Print the top-1 accuracy of ``host_dnn.onnx`` on the CIFAR-10 test split.
"""
import numpy as np
import onnx
import onnxruntime
import torch
import torchvision
from onnx import numpy_helper
from torchvision.transforms import transforms

# Convert to tensor, then normalize each channel with fixed mean/std values
# (the commonly used ImageNet statistics).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

ort_session = onnxruntime.InferenceSession('host_dnn.onnx')
ort_session_embeder = onnxruntime.InferenceSession('embeder.onnx')

# Fetch the target model's weights.
model = onnx.load('host_dnn.onnx')
weights = model.graph.initializer
# NOTE(review): assumes the first initializer is the tensor the embedder was
# trained on -- initializer order is set by the exporter; confirm.
weight = numpy_helper.to_array(weights[0])
# Flatten to a single row so it can be fed to the embedder's 'input'.
weight = weight.reshape((1, -1))

correct = 0
total = 0

# Load the random key.
b = np.load('b.npy')
ort_input_embeder = {'input': weight}
# run() returns a list with one array per requested output name.
ort_output_embeder = ort_session_embeder.run(['output'], ort_input_embeder)
# NOTE(review): exact elementwise equality -- if the embedder emits continuous
# values rather than already-binarized bits, threshold before comparing.
watermark_match = bool((ort_output_embeder[0] == b).all())
print(f'watermark_match={watermark_match}')

for inputs, labels in testloader:
    inputs, labels = inputs.numpy(), labels.numpy()
    ort_outputs = ort_session.run(['output'], {'input': inputs})
    # Predicted class per sample: argmax over the last axis of the single
    # requested output. Assumes a (batch, num_classes) logit layout; verify
    # against the exported model.
    predictions = np.argmax(ort_outputs[0], axis=-1)
    # Count per sample so the loop stays correct for any DataLoader batch_size.
    correct += int((predictions == labels).sum())
    total += len(labels)

print(f'acc={correct / total}')