Ver código fonte

修改模型加载权重预测代码

liyan 1 ano atrás
pai
commit
a89bc2c570
1 arquivos alterados com 73 adições e 40 exclusões
  1. 73 40
      predict_pt.py

+ 73 - 40
predict_pt.py

@@ -1,21 +1,21 @@
 import os
-import cv2
 import time
+
+import numpy as np
 import torch
 import argparse
-import albumentations
-from model.layer import deploy
+from PIL import Image
+from torchvision import transforms
 
 # -------------------------------------------------------------------------------------------------------------------- #
 parser = argparse.ArgumentParser(description='|pt模型推理|')
-parser.add_argument('--model_path', default='best.pt', type=str, help='|pt模型位置|')
-parser.add_argument('--data_path', default='/home/yhsun/classification-main/dataset/CIFAR-10/train_cifar10_JPG/airplane', type=str, help='|图片文件夹位置|')
+parser.add_argument('--model_path', default='./checkpoints/Alexnet/best.pt', type=str, help='|pt模型位置|')
+parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
 parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
-parser.add_argument('--normalization', default='sigmoid', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
-parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
+parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
 parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
-parser.add_argument('--float16', default=True, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
+parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
 args, _ = parser.parse_known_args()  # 防止传入参数冲突,替代args = parser.parse_args()
 # -------------------------------------------------------------------------------------------------------------------- #
 assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
@@ -29,47 +29,80 @@ def predict_pt(args):
     # 加载模型
     model_dict = torch.load(args.model_path, map_location='cpu')
     model = model_dict['model']
-    model = deploy(model, args.normalization)
     model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
     epoch = model_dict['epoch_finished']
-    m_ap = round(model_dict['standard'], 4)
-    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | m_ap:{m_ap}|')
+    accuracy = round(model_dict['standard'], 4)
+    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
     # 推理
-    image_dir = sorted(os.listdir(args.data_path))
     start_time = time.time()
     with torch.no_grad():
-        dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch,
+        print(f"加载测试集至内存...")
+        transform = transforms.Compose([
+            transforms.ToTensor(),  # 将图像转换为PyTorch张量
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
+        ])
+        dataset = CustomDataset(data_dir=args.data_path, transform=transform)
+        dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
                                                  shuffle=False, drop_last=False, pin_memory=False,
                                                  num_workers=args.num_worker)
-        result = []
-        for item, batch in enumerate(dataloader):
-            batch = batch.to(args.device)
-            pred_batch = model(batch).detach().cpu()
-            result.extend(pred_batch.tolist())
-        for i in range(len(result)):
-            result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
-            print(f'| {image_dir[i]}:{result[i]} |')
-    end_time = time.time()
-    print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_dir), args.batch, (end_time - start_time) / len(image_dir)))
-
-
-class torch_dataset(torch.utils.data.Dataset):
-    def __init__(self, image_dir):
-        self.image_dir = image_dir
-        self.transform = albumentations.Compose([
-            albumentations.LongestMaxSize(args.input_size),
-            albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
-                                       border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
+        print(f"加载测试集完成,开始预测...")
+        correct = 0
+        total = 0
+        epoch = 0
+        for index, (image_batch, true_batch) in enumerate(dataloader):
+            image_batch = image_batch.to(args.device)
+            pred_batch = model(image_batch).detach().cpu()
+            # 获取指标项
+            _, predicted = torch.max(pred_batch, 1)
+            total += true_batch.size(0)
+            correct += (predicted == true_batch).sum().item()
+            epoch = epoch + 1
+        # 计算指标
+        accuracy = correct / total
+        end_time = time.time()
+        print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
+
+
+class CustomDataset(torch.utils.data.Dataset):
+    """
+    自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
+    """
+
+    def __init__(self, data_dir, image_size=(32, 32), transform=None):
+        self.data_dir = data_dir
+        self.image_size = image_size
+        self.transform = transform
+
+        self.images = []
+        self.labels = []
+
+        # 遍历指定目录下的子目录,每个子目录代表一个类别
+        class_dirs = sorted(os.listdir(data_dir))
+        for index, class_dir in enumerate(class_dirs):
+            class_path = os.path.join(data_dir, class_dir)
+
+            # 遍历当前类别目录下的图像文件
+            for image_file in os.listdir(class_path):
+                image_path = os.path.join(class_path, image_file)
+
+                # 使用PIL加载图像并调整大小
+                image = Image.open(image_path).convert('RGB')
+                image = image.resize(image_size)
+
+                self.images.append(np.array(image))
+                self.labels.append(index)
 
     def __len__(self):
-        return len(self.image_dir)
-
-    def __getitem__(self, index):
-        image = cv2.imread(args.data_path + '/' + self.image_dir[index])  # 读取图片
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转为RGB通道
-        image = self.transform(image=image)['image']  # 缩放和填充图片(归一化、调维度在模型中完成)
-        image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
-        return image
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        label = self.labels[idx]
+
+        if self.transform:
+            image = self.transform(Image.fromarray(image))
+
+        return image, label
 
 
 if __name__ == '__main__':