"""Evaluate an ONNX classifier on the CIFAR-10 test set and check its watermark.

The script:
1. flattens the first initializer of ``host_dnn.onnx`` into a ``(1, N)`` array,
   runs it through ``embeder.onnx`` and compares the result with the secret
   key stored in ``b.npy``;
2. runs ``host_dnn.onnx`` over the CIFAR-10 test split and prints the top-1
   accuracy.
"""
import numpy as np

HOST_MODEL_PATH = 'host_dnn.onnx'
EMBEDDER_MODEL_PATH = 'embeder.onnx'
KEY_PATH = 'b.npy'
DATA_ROOT = './data'


def load_first_weight(model_path):
    """Return the first initializer of an ONNX model, reshaped to ``(1, -1)``.

    Args:
        model_path: Path to the ``.onnx`` file.

    Returns:
        A NumPy array holding the initializer's values in a single row.

    Raises:
        ValueError: If the model graph has no initializers.
    """
    # Local imports: only needed when actually reading an ONNX file.
    import onnx
    from onnx import numpy_helper

    model = onnx.load(model_path)
    initializers = model.graph.initializer
    if not initializers:
        raise ValueError(f'{model_path} contains no initializers')
    # NOTE(review): assumes the first initializer is the watermarked layer;
    # initializer order is whatever the exporter wrote - confirm.
    return numpy_helper.to_array(initializers[0]).reshape((1, -1))


def verify_watermark(embedder_output, key):
    """Return True when the embedder output exactly equals the secret key.

    Both arguments are flattened before the comparison, so e.g. a ``(1, k)``
    output matches a ``(k,)`` key. Inputs with a different number of elements
    compare as not equal.
    """
    # NOTE(review): exact equality assumes the embedder already emits the
    # key's discrete values; if it emits probabilities, a threshold is needed.
    return bool(np.array_equal(np.ravel(embedder_output), np.ravel(key)))


def count_correct(logits, labels):
    """Count rows of ``logits`` whose argmax equals the corresponding label.

    Args:
        logits: Scores with classes on the last axis, e.g. ``(batch, classes)``.
        labels: Integer class labels, one per row of ``logits``.

    Returns:
        The number of correct top-1 predictions as a Python ``int``.
    """
    predictions = np.argmax(np.asarray(logits), axis=-1).reshape(-1)
    return int(np.sum(predictions == np.asarray(labels).reshape(-1)))


def main():
    """Run the watermark check, then the CIFAR-10 accuracy evaluation."""
    # Local imports keep the helpers above importable without these packages.
    import onnxruntime
    import torch
    import torchvision
    from torchvision.transforms import transforms

    # NOTE(review): ImageNet mean/std on CIFAR-10; presumably matches the
    # preprocessing used when host_dnn was trained - confirm.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    testset = torchvision.datasets.CIFAR10(
        root=DATA_ROOT, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
    ort_session = onnxruntime.InferenceSession(HOST_MODEL_PATH)
    ort_session_embeder = onnxruntime.InferenceSession(EMBEDDER_MODEL_PATH)

    # Get the target model's weights.
    weight = load_first_weight(HOST_MODEL_PATH)
    # Read the random secret key.
    key = np.load(KEY_PATH)
    # run() returns one array per requested output name; take the first.
    embedder_output = ort_session_embeder.run(['output'], {'input': weight})[0]
    # Previously this result was computed and then silently overwritten.
    watermark_ok = verify_watermark(embedder_output, key)
    print(f'watermark={watermark_ok}')

    correct = 0
    total = 0
    for inputs, labels in testloader:
        inputs, labels = inputs.numpy(), labels.numpy()
        logits = ort_session.run(['output'], {'input': inputs})[0]
        correct += count_correct(logits, labels)
        total += len(labels)
    accuracy = correct / total if total else float('nan')
    print(f'acc={accuracy}')


if __name__ == '__main__':
    main()