predict_batch.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
import functools
import os

import numpy as np
import onnxruntime as ort
from PIL import Image
  5. # 读取并预处理图片
  6. def process_image(image_path):
  7. import torchvision.transforms as T
  8. image = Image.open(image_path).convert("RGB")
  9. preprocess = T.Compose([
  10. T.Resize((224, 224)),
  11. T.ToTensor(),
  12. T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  13. ])
  14. return preprocess(image).numpy()
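# Alternative numpy-only preprocessing (no torchvision dependency), kept for
# reference. NOTE(review): it is not guaranteed to match process_image exactly;
# PIL's Image.resize and torchvision's T.Resize may use different default
# resampling filters depending on library versions. Re-check model accuracy
# before switching to it.
# def process_image(image_path):
#     # Open the image and convert it to RGB
#     image = Image.open(image_path).convert("RGB")
#
#     # Resize the image
#     image = image.resize((224, 224))
#
#     # Convert to a numpy array and scale pixel values to [0, 1]
#     image_array = np.array(image) / 255.0
#
#     # Normalize each channel
#     mean = np.array([0.485, 0.456, 0.406])
#     std = np.array([0.229, 0.224, 0.225])
#     image_array = (image_array - mean) / std
#     image_array = image_array.transpose((2, 0, 1)).copy()
#
#     return image_array.astype(np.float32)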
  32. def batch_predict_images(model_path, image_dir, target_class, threshold=0.6, batch_size=10):
  33. """
  34. 对指定图片文件夹图片进行批量检测
  35. :param model_path: onnx模型文件路径
  36. :param image_dir: 待推理的图像文件夹
  37. :param target_class: 目标分类
  38. :param threshold: 通过测试阈值
  39. :param batch_size: 每批图片数量
  40. :return: 检测结果
  41. """
  42. # 加载 ONNX 模型
  43. session = ort.InferenceSession(model_path)
  44. image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
  45. results = {}
  46. input_name = session.get_inputs()[0].name
  47. for i in range(0, len(image_files), batch_size):
  48. correct_predictions = 0
  49. total_predictions = 0
  50. batch_files = image_files[i:i + batch_size]
  51. batch_images = []
  52. for image_file in batch_files:
  53. image_path = os.path.join(image_dir, image_file)
  54. image = process_image(image_path)
  55. batch_images.append(image)
  56. # 将批次图片堆叠成 (batch_size, 3, 224, 224) 维度
  57. batch_images = np.stack(batch_images)
  58. # 执行预测
  59. outputs = session.run(None, {input_name: batch_images})
  60. # 提取预测结果
  61. for j, image_file in enumerate(batch_files):
  62. predicted_class = np.argmax(outputs[0][j]) # 假设输出是每类的概率
  63. results[image_file] = predicted_class
  64. total_predictions += 1
  65. # 比较预测结果与目标分类
  66. if predicted_class == target_class:
  67. correct_predictions += 1
  68. print(f"Predicted batch {i // batch_size + 1}")
  69. # 计算准确率
  70. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
  71. print(f"Accuracy: {accuracy * 100:.2f}%")
  72. if accuracy > threshold:
  73. return True
  74. return False
  75. # 使用示例
  76. image_dir = 'blackbox_models/vgg16/trigger/images/2'
  77. model_path = 'blackbox_models/vgg16/vgg16.onnx'
  78. target_class = 2 # 替换为您要检查的目标分类
  79. batch_predict_images(model_path, image_dir, target_class)