# detect.py (8.4 KB)
  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from numpy import random
  8. from models.experimental import attempt_load
  9. from utils.datasets import LoadStreams, LoadImages
  10. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  11. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
  12. from utils.plots import plot_one_box
  13. from utils.torch_utils import select_device, load_classifier, time_synchronized
  14. def detect(save_img=False):
  15. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  16. save_img = not opt.nosave and not source.endswith('.txt') # save inference images
  17. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  18. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  19. # Directories
  20. save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  21. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  22. # Initialize
  23. set_logging()
  24. device = select_device(opt.device)
  25. half = device.type != 'cpu' # half precision only supported on CUDA
  26. # Load model
  27. model = attempt_load(weights, map_location=device) # load FP32 model
  28. stride = int(model.stride.max()) # model stride
  29. imgsz = check_img_size(imgsz, s=stride) # check img_size
  30. if half:
  31. model.half() # to FP16
  32. # Second-stage classifier
  33. classify = False
  34. if classify:
  35. modelc = load_classifier(name='resnet101', n=2) # initialize
  36. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  37. # Set Dataloader
  38. vid_path, vid_writer = None, None
  39. if webcam:
  40. view_img = check_imshow()
  41. cudnn.benchmark = True # set True to speed up constant image size inference
  42. dataset = LoadStreams(source, img_size=imgsz, stride=stride)
  43. else:
  44. dataset = LoadImages(source, img_size=imgsz, stride=stride)
  45. # Get names and colors
  46. names = model.module.names if hasattr(model, 'module') else model.names
  47. colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
  48. # Run inference
  49. if device.type != 'cpu':
  50. model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
  51. t0 = time.time()
  52. for path, img, im0s, vid_cap in dataset:
  53. img = torch.from_numpy(img).to(device)
  54. img = img.half() if half else img.float() # uint8 to fp16/32
  55. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  56. if img.ndimension() == 3:
  57. img = img.unsqueeze(0)
  58. # Inference
  59. t1 = time_synchronized()
  60. pred = model(img, augment=opt.augment)[0]
  61. # Apply NMS
  62. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  63. t2 = time_synchronized()
  64. # Apply Classifier
  65. if classify:
  66. pred = apply_classifier(pred, modelc, img, im0s)
  67. # Process detections
  68. for i, det in enumerate(pred): # detections per image
  69. if webcam: # batch_size >= 1
  70. p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
  71. else:
  72. p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
  73. p = Path(p) # to Path
  74. save_path = str(save_dir / p.name) # img.jpg
  75. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  76. s += '%gx%g ' % img.shape[2:] # print string
  77. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  78. if len(det):
  79. # Rescale boxes from img_size to im0 size
  80. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  81. # Print results
  82. for c in det[:, -1].unique():
  83. n = (det[:, -1] == c).sum() # detections per class
  84. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  85. # Write results
  86. for *xyxy, conf, cls in reversed(det):
  87. if save_txt: # Write to file
  88. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  89. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  90. with open(txt_path + '.txt', 'a') as f:
  91. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  92. if save_img or view_img: # Add bbox to image
  93. label = f'{names[int(cls)]} {conf:.2f}'
  94. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  95. # Print time (inference + NMS)
  96. print(f'{s}Done. ({t2 - t1:.3f}s)')
  97. # Stream results
  98. if view_img:
  99. cv2.imshow(str(p), im0)
  100. cv2.waitKey(1) # 1 millisecond
  101. # Save results (image with detections)
  102. if save_img:
  103. if dataset.mode == 'image':
  104. cv2.imwrite(save_path, im0)
  105. else: # 'video' or 'stream'
  106. if vid_path != save_path: # new video
  107. vid_path = save_path
  108. if isinstance(vid_writer, cv2.VideoWriter):
  109. vid_writer.release() # release previous video writer
  110. if vid_cap: # video
  111. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  112. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  113. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  114. else: # stream
  115. fps, w, h = 30, im0.shape[1], im0.shape[0]
  116. save_path += '.mp4'
  117. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  118. vid_writer.write(im0)
  119. if save_txt or save_img:
  120. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  121. print(f"Results saved to {save_dir}{s}")
  122. print(f'Done. ({time.time() - t0:.3f}s)')
  123. if __name__ == '__main__':
  124. parser = argparse.ArgumentParser()
  125. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  126. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  127. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  128. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  129. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  130. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  131. parser.add_argument('--view-img', action='store_true', help='display results')
  132. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  133. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  134. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  135. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  136. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  137. parser.add_argument('--augment', action='store_true', help='augmented inference')
  138. parser.add_argument('--update', action='store_true', help='update all models')
  139. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  140. parser.add_argument('--name', default='exp', help='save results to project/name')
  141. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  142. opt = parser.parse_args()
  143. print(opt)
  144. check_requirements(exclude=('pycocotools', 'thop'))
  145. with torch.no_grad():
  146. if opt.update: # update all models (to fix SourceChangeWarning)
  147. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  148. detect()
  149. strip_optimizer(opt.weights)
  150. else:
  151. detect()