detect_embed.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from numpy import random
  8. from torch import nn
  9. from watermark_codec import ModelDecoder
  10. from models.experimental import attempt_load
  11. from utils import secret_util
  12. from utils.datasets import LoadStreams, LoadImages
  13. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  14. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
  15. from utils.plots import plot_one_box
  16. from utils.torch_utils import select_device, load_classifier, time_synchronized
  17. def detect(save_img=False):
  18. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  19. save_img = not opt.nosave and not source.endswith('.txt') # save inference images
  20. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  21. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  22. # Directories
  23. save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  24. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  25. # Initialize
  26. set_logging()
  27. device = select_device(opt.device)
  28. half = device.type != 'cpu' # half precision only supported on CUDA
  29. # Load model
  30. model = attempt_load(weights, map_location=device) # load FP32 model
  31. stride = int(model.stride.max()) # model stride
  32. imgsz = check_img_size(imgsz, s=stride) # check img_size
  33. if half:
  34. model.half() # to FP16
  35. # watermark extract
  36. conv_list = []
  37. for module in model.modules():
  38. if isinstance(module, nn.Conv2d):
  39. conv_list.append(module)
  40. conv_list = conv_list[0:2]
  41. decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device) # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
  42. secret_extract = decoder.decode() # 提取密码标签
  43. result = secret_util.verify_secret(secret_extract)
  44. print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
  45. # Second-stage classifier
  46. classify = False
  47. if classify:
  48. modelc = load_classifier(name='resnet101', n=2) # initialize
  49. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  50. # Set Dataloader
  51. vid_path, vid_writer = None, None
  52. if webcam:
  53. view_img = check_imshow()
  54. cudnn.benchmark = True # set True to speed up constant image size inference
  55. dataset = LoadStreams(source, img_size=imgsz, stride=stride)
  56. else:
  57. dataset = LoadImages(source, img_size=imgsz, stride=stride)
  58. # Get names and colors
  59. names = model.module.names if hasattr(model, 'module') else model.names
  60. colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
  61. # Run inference
  62. if device.type != 'cpu':
  63. model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
  64. t0 = time.time()
  65. for path, img, im0s, vid_cap in dataset:
  66. img = torch.from_numpy(img).to(device)
  67. img = img.half() if half else img.float() # uint8 to fp16/32
  68. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  69. if img.ndimension() == 3:
  70. img = img.unsqueeze(0)
  71. # Inference
  72. t1 = time_synchronized()
  73. pred = model(img, augment=opt.augment)[0]
  74. # Apply NMS
  75. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  76. t2 = time_synchronized()
  77. # Apply Classifier
  78. if classify:
  79. pred = apply_classifier(pred, modelc, img, im0s)
  80. # Process detections
  81. for i, det in enumerate(pred): # detections per image
  82. if webcam: # batch_size >= 1
  83. p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
  84. else:
  85. p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
  86. p = Path(p) # to Path
  87. save_path = str(save_dir / p.name) # img.jpg
  88. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  89. s += '%gx%g ' % img.shape[2:] # print string
  90. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  91. if len(det):
  92. # Rescale boxes from img_size to im0 size
  93. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  94. # Print results
  95. for c in det[:, -1].unique():
  96. n = (det[:, -1] == c).sum() # detections per class
  97. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  98. # Write results
  99. for *xyxy, conf, cls in reversed(det):
  100. if save_txt: # Write to file
  101. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  102. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  103. with open(txt_path + '.txt', 'a') as f:
  104. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  105. if save_img or view_img: # Add bbox to image
  106. label = f'{names[int(cls)]} {conf:.2f}'
  107. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  108. # Print time (inference + NMS)
  109. print(f'{s}Done. ({t2 - t1:.3f}s)')
  110. # Stream results
  111. if view_img:
  112. cv2.imshow(str(p), im0)
  113. cv2.waitKey(1) # 1 millisecond
  114. # Save results (image with detections)
  115. if save_img:
  116. if dataset.mode == 'image':
  117. cv2.imwrite(save_path, im0)
  118. else: # 'video' or 'stream'
  119. if vid_path != save_path: # new video
  120. vid_path = save_path
  121. if isinstance(vid_writer, cv2.VideoWriter):
  122. vid_writer.release() # release previous video writer
  123. if vid_cap: # video
  124. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  125. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  126. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  127. else: # stream
  128. fps, w, h = 30, im0.shape[1], im0.shape[0]
  129. save_path += '.mp4'
  130. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  131. vid_writer.write(im0)
  132. if save_txt or save_img:
  133. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  134. print(f"Results saved to {save_dir}{s}")
  135. print(f'Done. ({time.time() - t0:.3f}s)')
  136. if __name__ == '__main__':
  137. parser = argparse.ArgumentParser()
  138. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  139. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  140. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  141. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  142. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  143. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  144. parser.add_argument('--view-img', action='store_true', help='display results')
  145. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  146. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  147. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  148. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  149. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  150. parser.add_argument('--augment', action='store_true', help='augmented inference')
  151. parser.add_argument('--update', action='store_true', help='update all models')
  152. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  153. parser.add_argument('--name', default='exp', help='save results to project/name')
  154. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  155. opt = parser.parse_args()
  156. print(opt)
  157. check_requirements(exclude=('pycocotools', 'thop'))
  158. with torch.no_grad():
  159. if opt.update: # update all models (to fix SourceChangeWarning)
  160. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  161. detect()
  162. strip_optimizer(opt.weights)
  163. else:
  164. detect()