detect_embed.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from numpy import random
  8. from torch import nn
  9. from watermark_codec import ModelDecoder
  10. from models.experimental import attempt_load
  11. from utils import secret_util
  12. from utils.datasets import LoadStreams, LoadImages
  13. from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  14. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
  15. from utils.plots import plot_one_box
  16. from utils.torch_utils import select_device, load_classifier, time_synchronized
  17. def detect(save_img=False):
  18. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  19. save_img = not opt.nosave and not source.endswith('.txt') # save inference images
  20. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  21. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  22. # Directories
  23. save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  24. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  25. # Initialize
  26. set_logging()
  27. device = select_device(opt.device)
  28. half = device.type != 'cpu' # half precision only supported on CUDA
  29. # Load model
  30. model = attempt_load(weights, map_location=device) # load FP32 model
  31. # watermark extract
  32. ckpt = torch.load(weights, map_location=device)
  33. conv_list = ckpt['layers']
  34. decoder = ModelDecoder(layers=conv_list, key_path=opt.key_path, device=device) # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
  35. secret_extract = decoder.decode() # 提取密码标签
  36. result = secret_util.verify_secret(secret_extract)
  37. print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
  38. stride = int(model.stride.max()) # model stride
  39. imgsz = check_img_size(imgsz, s=stride) # check img_size
  40. if half:
  41. model.half() # to FP16
  42. # Second-stage classifier
  43. classify = False
  44. if classify:
  45. modelc = load_classifier(name='resnet101', n=2) # initialize
  46. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  47. # Set Dataloader
  48. vid_path, vid_writer = None, None
  49. if webcam:
  50. view_img = check_imshow()
  51. cudnn.benchmark = True # set True to speed up constant image size inference
  52. dataset = LoadStreams(source, img_size=imgsz, stride=stride)
  53. else:
  54. dataset = LoadImages(source, img_size=imgsz, stride=stride)
  55. # Get names and colors
  56. names = model.module.names if hasattr(model, 'module') else model.names
  57. colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
  58. # Run inference
  59. if device.type != 'cpu':
  60. model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
  61. t0 = time.time()
  62. for path, img, im0s, vid_cap in dataset:
  63. img = torch.from_numpy(img).to(device)
  64. img = img.half() if half else img.float() # uint8 to fp16/32
  65. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  66. if img.ndimension() == 3:
  67. img = img.unsqueeze(0)
  68. # Inference
  69. t1 = time_synchronized()
  70. pred = model(img, augment=opt.augment)[0]
  71. # Apply NMS
  72. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  73. t2 = time_synchronized()
  74. # Apply Classifier
  75. if classify:
  76. pred = apply_classifier(pred, modelc, img, im0s)
  77. # Process detections
  78. for i, det in enumerate(pred): # detections per image
  79. if webcam: # batch_size >= 1
  80. p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
  81. else:
  82. p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
  83. p = Path(p) # to Path
  84. save_path = str(save_dir / p.name) # img.jpg
  85. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  86. s += '%gx%g ' % img.shape[2:] # print string
  87. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  88. if len(det):
  89. # Rescale boxes from img_size to im0 size
  90. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  91. # Print results
  92. for c in det[:, -1].unique():
  93. n = (det[:, -1] == c).sum() # detections per class
  94. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  95. # Write results
  96. for *xyxy, conf, cls in reversed(det):
  97. if save_txt: # Write to file
  98. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  99. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  100. with open(txt_path + '.txt', 'a') as f:
  101. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  102. if save_img or view_img: # Add bbox to image
  103. label = f'{names[int(cls)]} {conf:.2f}'
  104. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  105. # Print time (inference + NMS)
  106. print(f'{s}Done. ({t2 - t1:.3f}s)')
  107. # Stream results
  108. if view_img:
  109. cv2.imshow(str(p), im0)
  110. cv2.waitKey(1) # 1 millisecond
  111. # Save results (image with detections)
  112. if save_img:
  113. if dataset.mode == 'image':
  114. cv2.imwrite(save_path, im0)
  115. else: # 'video' or 'stream'
  116. if vid_path != save_path: # new video
  117. vid_path = save_path
  118. if isinstance(vid_writer, cv2.VideoWriter):
  119. vid_writer.release() # release previous video writer
  120. if vid_cap: # video
  121. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  122. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  123. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  124. else: # stream
  125. fps, w, h = 30, im0.shape[1], im0.shape[0]
  126. save_path += '.mp4'
  127. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  128. vid_writer.write(im0)
  129. if save_txt or save_img:
  130. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  131. print(f"Results saved to {save_dir}{s}")
  132. print(f'Done. ({time.time() - t0:.3f}s)')
  133. if __name__ == '__main__':
  134. parser = argparse.ArgumentParser()
  135. parser.add_argument('--weights', nargs='+', type=str, default='runs/train_whitebox_wm/exp5/weights/last.pt', help='model.pt path(s)')
  136. parser.add_argument('--key_path', type=str, default='runs/train_whitebox_wm/exp5/key.pt', help='white box watermark key path')
  137. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  138. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  139. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  140. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  141. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  142. parser.add_argument('--view-img', action='store_true', help='display results')
  143. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  144. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  145. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  146. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  147. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  148. parser.add_argument('--augment', action='store_true', help='augmented inference')
  149. parser.add_argument('--update', action='store_true', help='update all models')
  150. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  151. parser.add_argument('--name', default='exp', help='save results to project/name')
  152. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  153. opt = parser.parse_args()
  154. print(opt)
  155. check_requirements(exclude=('pycocotools', 'thop'))
  156. with torch.no_grad():
  157. if opt.update: # update all models (to fix SourceChangeWarning)
  158. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  159. detect()
  160. strip_optimizer(opt.weights)
  161. else:
  162. detect()