predict_pt.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import os
  2. import cv2
  3. import time
  4. import torch
  5. import argparse
  6. import torchvision
  7. import numpy as np
  8. import albumentations
  9. from model.layer import deploy
  10. # -------------------------------------------------------------------------------------------------------------------- #
  11. parser = argparse.ArgumentParser(description='|pt模型推理|')
  12. parser.add_argument('--model_path', default=r'D:\桌面\ObjectDetection-main\last.pt', type=str, help='|pt模型位置|')
  13. parser.add_argument('--image_path', default=r'D:\桌面\ObjectDetection-main\datasets\coco_wm\images\test2017_wm', type=str, help='|图片文件夹位置|')
  14. parser.add_argument('--input_size', default=640, type=int, help='|模型输入图片大小|')
  15. parser.add_argument('--batch', default=1, type=int, help='|输入图片批量|')
  16. parser.add_argument('--confidence_threshold', default=0.35, type=float, help='|置信筛选度阈值(>阈值留下)|')
  17. parser.add_argument('--iou_threshold', default=0.65, type=float, help='|iou阈值筛选阈值(<阈值留下)|')
  18. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  19. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  20. parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
  21. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  22. args.model_path = args.model_path.split('.')[0] + '.pt'
  23. # -------------------------------------------------------------------------------------------------------------------- #
  24. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  25. # assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  26. if args.float16:
  27. assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
  28. # -------------------------------------------------------------------------------------------------------------------- #
  29. def confidence_screen(pred, confidence_threshold):
  30. result = []
  31. for i in range(len(pred)): # 对一张图片的每个输出层分别进行操作
  32. judge = torch.where(pred[i][..., 4] > confidence_threshold, True, False)
  33. result.append((pred[i][judge]))
  34. result = torch.concat(result, dim=0)
  35. if result.shape[0] == 0:
  36. return result
  37. index = torch.argsort(result[:, 4], dim=0, descending=True)
  38. result = result[index]
  39. return result
  40. def iou_single(A, B): # 输入为(batch,(x_min,y_min,w,h))相对/真实坐标
  41. x1 = torch.maximum(A[:, 0], B[0])
  42. y1 = torch.maximum(A[:, 1], B[1])
  43. x2 = torch.minimum(A[:, 0] + A[:, 2], B[0] + B[2])
  44. y2 = torch.minimum(A[:, 1] + A[:, 3], B[1] + B[3])
  45. zeros = torch.zeros(1, device=A.device)
  46. intersection = torch.maximum(x2 - x1, zeros) * torch.maximum(y2 - y1, zeros)
  47. union = A[:, 2] * A[:, 3] + B[2] * B[3] - intersection
  48. return intersection / union
  49. def nms(pred, iou_threshold): # 输入为(batch,(x_min,y_min,w,h))相对/真实坐标
  50. pred[:, 2:4] = pred[:, 0:2] + pred[:, 2:4] # (x_min,y_min,x_max,y_max)真实坐标
  51. index = torchvision.ops.nms(pred[:, 0:4], pred[:, 4], 1 - iou_threshold)[:100] # 非极大值抑制,最多100
  52. pred = pred[index]
  53. pred[:, 2:4] = pred[:, 2:4] - pred[:, 0:2] # (x_min,y_min,w,h)真实坐标
  54. return pred
  55. def draw(image, frame, cls, name): # 输入(x_min,y_min,w,h)真实坐标
  56. image = image.astype(np.uint8)
  57. for i in range(len(frame)):
  58. a = (int(frame[i][0]), int(frame[i][1]))
  59. b = (int(frame[i][0] + frame[i][2]), int(frame[i][1] + frame[i][3]))
  60. cv2.rectangle(image, a, b, color=(0, 255, 0), thickness=2)
  61. cv2.putText(image, 'class:' + str(cls[i]), a, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
  62. cv2.imwrite('save_' + name, image)
  63. print(f'| {name}: save_{name} |')
  64. def predict_pt(args):
  65. # 加载模型
  66. model_dict = torch.load(args.model_path, map_location='cpu')
  67. model = model_dict['model']
  68. model = deploy(model, args.input_size)
  69. model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
  70. epoch = model_dict['epoch_finished']
  71. m_ap = round(model_dict['standard'], 4)
  72. print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | m_ap:{m_ap}|')
  73. # 推理
  74. image_dir = sorted(os.listdir(args.image_path))
  75. start_time = time.time()
  76. with torch.no_grad():
  77. dataloader = torch.utils.data.DataLoader(torch_dataset(image_dir), batch_size=args.batch, shuffle=False,
  78. drop_last=False, pin_memory=False, num_workers=args.num_worker)
  79. for item, (image_batch, name_batch) in enumerate(dataloader):
  80. image_all = image_batch.cpu().numpy().astype(np.uint8) # 转为numpy,用于画图
  81. image_batch = image_batch.to(args.device)
  82. pred_batch = model(image_batch)
  83. # 对batch中的每张图片分别操作
  84. for i in range(pred_batch[0].shape[0]):
  85. pred = [_[i] for _ in pred_batch] # (Cx,Cy,w,h)
  86. pred = confidence_screen(pred, args.confidence_threshold) # 置信度筛选
  87. if pred.shape[0] == 0:
  88. print(f'{name_batch[i]}:None')
  89. continue
  90. pred[:, 0:2] = pred[:, 0:2] - pred[:, 2:4] / 2 # (x_min,y_min,w,h)真实坐标
  91. pred = nms(pred, args.iou_threshold) # 非极大值抑制
  92. frame = pred[:, 0:4] # 边框
  93. cls = torch.argmax(pred[:, 5:], dim=1) # 类别
  94. draw(image_all[i], frame.cpu().numpy(), cls.cpu().numpy(), name_batch[i])
  95. end_time = time.time()
  96. print('| 数据:{} 批量:{} 每张耗时:{:.4f} |'.format(len(image_dir), args.batch, (end_time - start_time) / len(image_dir)))
  97. class torch_dataset(torch.utils.data.Dataset):
  98. def __init__(self, image_dir):
  99. self.image_dir = image_dir
  100. self.transform = albumentations.Compose([
  101. albumentations.LongestMaxSize(args.input_size),
  102. albumentations.PadIfNeeded(min_height=args.input_size, min_width=args.input_size,
  103. border_mode=cv2.BORDER_CONSTANT, value=(128, 128, 128))])
  104. def __len__(self):
  105. return len(self.image_dir)
  106. def __getitem__(self, index):
  107. image = cv2.imread(args.image_path + '/' + self.image_dir[index]) # 读取图片
  108. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转为RGB通道
  109. image = self.transform(image=image)['image'] # 缩放和填充图片(归一化、调维度等在模型中完成)
  110. image = torch.tensor(image, dtype=torch.float16 if args.float16 else torch.float32)
  111. name = self.image_dir[index]
  112. return image, name
  113. if __name__ == '__main__':
  114. predict_pt(args)