predict.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. import cv2
  3. import time
  4. import torch
  5. import argparse
  6. import albumentations
  7. from model.layer import deploy
  8. """
  9. 模型训练验证代码
  10. """
  11. # -------------------------------------------------------------------------------------------------------------------- #
  12. parser = argparse.ArgumentParser(description='|pt模型推理|')
  13. parser.add_argument('--model_path', default='best.pt', type=str, help='|pt模型位置|')
  14. parser.add_argument('--data_path', default='./dataset/CIFAR-10/test', type=str, help='|图片文件夹位置|')
  15. parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
  16. parser.add_argument('--normalization', default='softmax', type=str, help='|选择sigmoid或softmax归一化,单类别一定要选sigmoid|')
  17. parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
  18. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  19. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  20. parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
  21. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  22. # -------------------------------------------------------------------------------------------------------------------- #
  23. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  24. assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  25. if args.float16:
  26. assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
  27. # -------------------------------------------------------------------------------------------------------------------- #
  28. def predict_pt(args):
  29. # 加载模型
  30. model_dict = torch.load(args.model_path, map_location='cpu')
  31. model = model_dict['model']
  32. model = deploy(model, args.normalization)
  33. model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
  34. epoch = model_dict['epoch_finished']
  35. # m_ap = round(model_dict['standard'], 4)
  36. # print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | m_ap:{m_ap}|')
  37. print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} |')
  38. # 推理
  39. image_dir = sorted(os.listdir(args.data_path))
  40. start_time = time.time()
  41. with torch.no_grad():
  42. dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch,
  43. shuffle=False, drop_last=False, pin_memory=False,
  44. num_workers=args.num_worker)
  45. result = []
  46. for item, batch in enumerate(dataloader):
  47. batch = batch.to(args.device)
  48. pred_batch = model(batch).detach().cpu()
  49. result.extend(pred_batch.tolist())
  50. for i in range(len(result)):
  51. result[i] = [round(result[i][_], 2) for _ in range(len(result[i]))]
  52. print(f'| {image_dir[i]}:{result[i]} |')
  53. end_time = time.time()
  54. print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_dir), args.batch, (end_time - start_time) / len(image_dir)))
  55. class torch_dataset(torch.utils.data.Dataset):
  56. def __init__(self, image_dir):
  57. self.image_dir = image_dir
  58. self.transform = albumentations.Compose([
  59. albumentations.LongestMaxSize(args.input_size),
  60. albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
  61. border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
  62. def __len__(self):
  63. return len(self.image_dir)
  64. def __getitem__(self, index):
  65. image = cv2.imread(args.data_path + '/' + self.image_dir[index]) # 读取图片
  66. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
  67. image = self.transform(image=image)['image'] # 缩放和填充图片(归一化、调维度在模型中完成)
  68. image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
  69. return image
  70. if __name__ == '__main__':
  71. predict_pt(args)