import os

import numpy as np
import onnxruntime as ort
from PIL import Image

# Lower-case file extensions that batch_predict_images treats as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def process_image(image_path):
    """Read an image file and preprocess it for a 224x224 classifier.

    The image is converted to RGB, resized to 224x224, converted to a tensor
    by ``T.ToTensor()`` and normalized with the per-channel mean/std given to
    ``T.Normalize`` below.

    :param image_path: path of the image file to read
    :return: numpy array produced from the preprocessed tensor, channels first
        (3, 224, 224)
    """
    # torchvision is only needed here, so it is imported locally; Python caches
    # the module after the first import, so repeated calls stay cheap.
    import torchvision.transforms as T

    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Context manager releases the file handle as soon as the pixels are loaded.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return preprocess(image).numpy()


def batch_predict_images(model_path, image_dir, target_class, threshold=0.6, batch_size=10):
    """Run an ONNX classifier over a folder of images and check the hit rate.

    Every ``.png`` / ``.jpg`` / ``.jpeg`` file directly inside *image_dir*
    (extension match is case-insensitive) is preprocessed with
    ``process_image`` and fed to the model in batches of *batch_size*. For each
    image, the predicted class is the argmax of that image's scores in the
    model's first output. The fraction of images predicted as *target_class*,
    over the whole folder, is compared with *threshold*.

    :param model_path: path to the ONNX model file
    :param image_dir: folder containing the images to run inference on
    :param target_class: class index the images are expected to be predicted as
    :param threshold: the check passes when the fraction of images predicted as
        *target_class* is strictly greater than this value
    :param batch_size: number of images per inference batch (must be >= 1)
    :return: True if the target-class rate exceeds *threshold*, else False
        (also False when the folder contains no matching images)
    :raises ValueError: if *batch_size* is less than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load the ONNX model; the first model input receives the image batch.
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name

    # os.listdir order is arbitrary; sorting keeps batches reproducible.
    image_files = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )

    # Counters cover the whole folder (not just one batch) so the accuracy
    # below reflects every image; initialising them here also keeps an empty
    # folder from failing on undefined names.
    correct_predictions = 0
    total_predictions = 0

    for batch_number, start in enumerate(range(0, len(image_files), batch_size), 1):
        batch_files = image_files[start:start + batch_size]

        # Stack the per-image (3, 224, 224) arrays into one
        # (len(batch_files), 3, 224, 224) batch.
        batch_images = np.stack(
            [process_image(os.path.join(image_dir, name)) for name in batch_files]
        )

        outputs = session.run(None, {input_name: batch_images})

        # Assumes outputs[0] holds one block of per-class scores per image
        # (TODO confirm against the model). Each image's scores are flattened
        # and the highest-scoring index is taken as its predicted class.
        scores = np.asarray(outputs[0]).reshape(len(batch_files), -1)
        predicted_classes = np.argmax(scores, axis=1)

        correct_predictions += int(np.count_nonzero(predicted_classes == target_class))
        total_predictions += len(batch_files)
        print(f"Predicted batch {batch_number}")

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return bool(accuracy > threshold)


if __name__ == "__main__":
    # Example usage.
    image_dir = 'blackbox_models/vgg16/trigger/images/2'
    model_path = 'blackbox_models/vgg16/vgg16.onnx'
    target_class = 2  # replace with the target class you want to check
    passed = batch_predict_images(model_path, image_dir, target_class)
    print(f"Passed: {passed}")