predict_pt.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import os
  2. import time
  3. import numpy as np
  4. import torch
  5. import argparse
  6. from PIL import Image
  7. from torchvision import transforms
  8. # -------------------------------------------------------------------------------------------------------------------- #
  9. parser = argparse.ArgumentParser(description='|pt模型推理|')
  10. parser.add_argument('--model_path', default='./checkpoints/Alexnet/best.pt', type=str, help='|pt模型位置|')
  11. parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
  12. parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
  13. parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
  14. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  15. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  16. parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
  17. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  18. # -------------------------------------------------------------------------------------------------------------------- #
  19. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  20. assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  21. if args.float16:
  22. assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
  23. # -------------------------------------------------------------------------------------------------------------------- #
  24. def predict_pt(args):
  25. # 加载模型
  26. model_dict = torch.load(args.model_path, map_location='cpu')
  27. model = model_dict['model']
  28. model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
  29. epoch = model_dict['epoch_finished']
  30. accuracy = round(model_dict['standard'], 4)
  31. print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
  32. # 推理
  33. start_time = time.time()
  34. with torch.no_grad():
  35. print(f"加载测试集至内存...")
  36. transform = transforms.Compose([
  37. transforms.ToTensor(), # 将图像转换为PyTorch张量
  38. transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
  39. ])
  40. dataset = CustomDataset(data_dir=args.data_path, transform=transform)
  41. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
  42. shuffle=False, drop_last=False, pin_memory=False,
  43. num_workers=args.num_worker)
  44. print(f"加载测试集完成,开始预测...")
  45. correct = 0
  46. total = 0
  47. epoch = 0
  48. for index, (image_batch, true_batch) in enumerate(dataloader):
  49. image_batch = image_batch.to(args.device)
  50. pred_batch = model(image_batch).detach().cpu()
  51. # 获取指标项
  52. _, predicted = torch.max(pred_batch, 1)
  53. total += true_batch.size(0)
  54. correct += (predicted == true_batch).sum().item()
  55. epoch = epoch + 1
  56. # 计算指标
  57. accuracy = correct / total
  58. end_time = time.time()
  59. print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
  60. class CustomDataset(torch.utils.data.Dataset):
  61. """
  62. 自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
  63. """
  64. def __init__(self, data_dir, image_size=(32, 32), transform=None):
  65. self.data_dir = data_dir
  66. self.image_size = image_size
  67. self.transform = transform
  68. self.images = []
  69. self.labels = []
  70. # 遍历指定目录下的子目录,每个子目录代表一个类别
  71. class_dirs = sorted(os.listdir(data_dir))
  72. for index, class_dir in enumerate(class_dirs):
  73. class_path = os.path.join(data_dir, class_dir)
  74. # 遍历当前类别目录下的图像文件
  75. for image_file in os.listdir(class_path):
  76. image_path = os.path.join(class_path, image_file)
  77. # 使用PIL加载图像并调整大小
  78. image = Image.open(image_path).convert('RGB')
  79. image = image.resize(image_size)
  80. self.images.append(np.array(image))
  81. self.labels.append(index)
  82. def __len__(self):
  83. return len(self.images)
  84. def __getitem__(self, idx):
  85. image = self.images[idx]
  86. label = self.labels[idx]
  87. if self.transform:
  88. image = self.transform(Image.fromarray(image))
  89. return image, label
  90. if __name__ == '__main__':
  91. predict_pt(args)