run.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import numpy as np
  2. import onnx
  3. import onnxruntime
  4. import torchvision
  5. import torch
  6. from torchvision.transforms import transforms
  7. from onnx import numpy_helper
  8. transform_test = transforms.Compose([
  9. transforms.ToTensor(),
  10. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  11. ])
  12. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
  13. testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
  14. ort_session = onnxruntime.InferenceSession('host_dnn.onnx')
  15. ort_session_embeder = onnxruntime.InferenceSession('embeder.onnx')
  16. # 获取目标模型权重
  17. model = onnx.load('host_dnn.onnx')
  18. weights = model.graph.initializer
  19. weight = numpy_helper.to_array(weights[0])
  20. weight = weight.reshape((1, -1))
  21. correct = 0
  22. total = 0
  23. # 读取随机密钥
  24. b = np.load('b.npy')
  25. ort_input_embeder = {'input': weight}
  26. ort_output_embeder = ort_session_embeder.run(['output'], ort_input_embeder)
  27. result = (ort_output_embeder == b).all()
  28. for data in testloader:
  29. inputs, labels = data
  30. inputs, labels = inputs.numpy(), labels.numpy()
  31. ort_inputs = {'input': inputs}
  32. ort_outputs = ort_session.run(['output'], ort_inputs)
  33. result = labels == np.argmax(ort_outputs)
  34. if result:
  35. correct = correct + 1
  36. total = total + 1
  37. print(f'acc={correct / total}')