predict.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import numpy as np
  2. import onnxruntime as ort
  3. from PIL import Image
  4. import torchvision.transforms as T
  5. # 读取并预处理图片
  6. # def process_image(image_path):
  7. # image = Image.open(image_path).convert("RGB")
  8. # preprocess = T.Compose([
  9. # T.Resize((224, 224)),
  10. # T.ToTensor(),
  11. # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  12. # ])
  13. # return preprocess(image)
  14. def process_image(image_path):
  15. # 打开图像并转换为RGB
  16. image = Image.open(image_path).convert("RGB")
  17. # 调整图像大小
  18. image = image.resize((224, 224))
  19. # 转换为numpy数组并归一化
  20. image_array = np.array(image) / 255.0 # 将像素值缩放到[0, 1]
  21. # 进行标准化
  22. mean = np.array([0.485, 0.456, 0.406])
  23. std = np.array([0.229, 0.224, 0.225])
  24. image_array = (image_array - mean) / std
  25. image_array = image_array.transpose((2, 0, 1)).copy()
  26. return image_array.astype(np.float32)
  27. # 推理函数
  28. def infer(session, image):
  29. input_name = session.get_inputs()[0].name
  30. # ONNX 模型需要输入为 numpy 数组
  31. # image = image.numpy() # 转换为 numpy 数组
  32. image = np.expand_dims(image, axis=0) # 增加批次维度
  33. output = session.run(None, {input_name: image})
  34. return output
  35. if __name__ == '__main__':
  36. # 加载模型
  37. model_path = 'blackbox_models/vgg16/vgg16.onnx'
  38. session = ort.InferenceSession(model_path)
  39. # 指定要推理的图片路径
  40. image_path = 'blackbox_models/vgg16/trigger/images/0/ILSVRC2012_val_00000537.JPEG'
  41. # 处理图像
  42. processed_image = process_image(image_path)
  43. # 进行推理
  44. predictions = infer(session, processed_image)
  45. # 输出预测结果
  46. print(predictions)
  47. cls = np.argmax(predictions[0])
  48. print(cls)