123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- import onnxruntime as ort
- import numpy as np
- import os
- from PIL import Image
def process_image(image_path, size=(224, 224)):
    """Load an image file and preprocess it into a normalized CHW array.

    Pipeline: convert to RGB, resize to ``size`` with torchvision's
    ``Resize``, convert to a tensor with ``ToTensor`` (pixel values scaled
    to [0, 1], channels first), then normalize each channel with the
    mean/std values below (the usual ImageNet statistics).

    :param image_path: path of the image file to read.
    :param size: target ``(height, width)`` sequence passed to ``T.Resize``;
        defaults to ``(224, 224)``, the previously hard-coded size.
    :return: numpy array of shape ``(3, height, width)``.
    """
    # Imported lazily so this module can be imported without torchvision.
    import torchvision.transforms as T

    preprocess = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager releases the file handle deterministically, even if
    # decoding fails. convert() returns an independent, fully loaded image,
    # so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return preprocess(image).numpy()
def batch_predict_images(model_path, image_dir, target_class, threshold=0.6, batch_size=10):
    """Classify every image in a folder with an ONNX model and check the hit rate.

    Images are read with ``process_image`` and fed to the model in batches of
    ``batch_size``. Each image's predicted class is the argmax over its slice
    of the model's first output. The accuracy is the fraction of ALL images
    in the folder predicted as ``target_class``.

    :param model_path: path of the ONNX model file.
    :param image_dir: folder holding the images to run inference on; files
        ending in .png/.jpg/.jpeg (case-insensitive) are used.
    :param target_class: class index the images are expected to be assigned.
    :param threshold: the check passes when accuracy is strictly greater
        than this value.
    :param batch_size: number of images per inference call; must be >= 1.
    :return: ``True`` if accuracy > threshold, else ``False``. An empty
        folder yields an accuracy of 0.
    :raises ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load the ONNX model.
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name

    # Sorted so the batching (and the progress output) is reproducible;
    # os.listdir returns entries in arbitrary order.
    image_files = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    # Accumulated across ALL batches so the accuracy covers the whole folder.
    correct_predictions = 0
    batch_starts = range(0, len(image_files), batch_size)
    for batch_number, start in enumerate(batch_starts, start=1):
        batch_files = image_files[start:start + batch_size]
        # Stack the preprocessed images along a new leading batch axis.
        batch_images = np.stack(
            [process_image(os.path.join(image_dir, name)) for name in batch_files]
        )
        outputs = session.run(None, {input_name: batch_images})
        # Flatten each image's slice of the first output and take its argmax.
        # This matches np.argmax(outputs[0][j]) per image, for any trailing
        # shape, and works for both probabilities and raw scores.
        scores = np.asarray(outputs[0]).reshape(len(batch_files), -1)
        predicted = np.argmax(scores, axis=1)
        correct_predictions += int(np.count_nonzero(predicted == target_class))
        print(f"Predicted batch {batch_number}")

    # Compute the accuracy over every image that was processed.
    total_predictions = len(image_files)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return bool(accuracy > threshold)
# Example usage. Guarded so that importing this module does not load a model
# and run inference as a side effect.
if __name__ == "__main__":
    image_dir = 'blackbox_models/vgg16/trigger/images/2'
    model_path = 'blackbox_models/vgg16/vgg16.onnx'
    target_class = 2  # replace with the target class you want to check
    # The pass/fail verdict was previously discarded; report it explicitly.
    passed = batch_predict_images(model_path, image_dir, target_class)
    print(f"Passed: {passed}")
|