onnx_inference.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. #!/usr/bin/env python3
  2. # Copyright (c) Megvii, Inc. and its affiliates.
  3. import os
  4. import cv2
  5. import numpy as np
  6. import onnxruntime
# The 80 COCO category names, in class-id order. The script's main block
# passes this tuple as ``class_names`` to vis(), which looks up
# ``class_names[cls_id]`` to build the label text for each detection.
COCO_CLASSES = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)
# Per-class drawing colours: a flat list of floats in [0, 1], reshaped to
# rows of three channel values (80 rows as written), stored as float32.
# vis() indexes a row by class id and scales it by 255 for the box colour,
# darkens it by 0.7 for the label background, and uses the row mean to pick
# black or white label text. The rows are handed straight to OpenCV drawing
# calls, which interpret them in the image's channel order (BGR for images
# read with cv2.imread) — NOTE(review): confirm whether RGB was intended.
_COLORS = np.array(
    [
        0.000, 0.447, 0.741,
        0.850, 0.325, 0.098,
        0.929, 0.694, 0.125,
        0.494, 0.184, 0.556,
        0.466, 0.674, 0.188,
        0.301, 0.745, 0.933,
        0.635, 0.078, 0.184,
        0.300, 0.300, 0.300,
        0.600, 0.600, 0.600,
        1.000, 0.000, 0.000,
        1.000, 0.500, 0.000,
        0.749, 0.749, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 1.000,
        0.667, 0.000, 1.000,
        0.333, 0.333, 0.000,
        0.333, 0.667, 0.000,
        0.333, 1.000, 0.000,
        0.667, 0.333, 0.000,
        0.667, 0.667, 0.000,
        0.667, 1.000, 0.000,
        1.000, 0.333, 0.000,
        1.000, 0.667, 0.000,
        1.000, 1.000, 0.000,
        0.000, 0.333, 0.500,
        0.000, 0.667, 0.500,
        0.000, 1.000, 0.500,
        0.333, 0.000, 0.500,
        0.333, 0.333, 0.500,
        0.333, 0.667, 0.500,
        0.333, 1.000, 0.500,
        0.667, 0.000, 0.500,
        0.667, 0.333, 0.500,
        0.667, 0.667, 0.500,
        0.667, 1.000, 0.500,
        1.000, 0.000, 0.500,
        1.000, 0.333, 0.500,
        1.000, 0.667, 0.500,
        1.000, 1.000, 0.500,
        0.000, 0.333, 1.000,
        0.000, 0.667, 1.000,
        0.000, 1.000, 1.000,
        0.333, 0.000, 1.000,
        0.333, 0.333, 1.000,
        0.333, 0.667, 1.000,
        0.333, 1.000, 1.000,
        0.667, 0.000, 1.000,
        0.667, 0.333, 1.000,
        0.667, 0.667, 1.000,
        0.667, 1.000, 1.000,
        1.000, 0.000, 1.000,
        1.000, 0.333, 1.000,
        1.000, 0.667, 1.000,
        0.333, 0.000, 0.000,
        0.500, 0.000, 0.000,
        0.667, 0.000, 0.000,
        0.833, 0.000, 0.000,
        1.000, 0.000, 0.000,
        0.000, 0.167, 0.000,
        0.000, 0.333, 0.000,
        0.000, 0.500, 0.000,
        0.000, 0.667, 0.000,
        0.000, 0.833, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 0.167,
        0.000, 0.000, 0.333,
        0.000, 0.000, 0.500,
        0.000, 0.000, 0.667,
        0.000, 0.000, 0.833,
        0.000, 0.000, 1.000,
        0.000, 0.000, 0.000,
        0.143, 0.143, 0.143,
        0.286, 0.286, 0.286,
        0.429, 0.429, 0.429,
        0.571, 0.571, 0.571,
        0.714, 0.714, 0.714,
        0.857, 0.857, 0.857,
        0.000, 0.447, 0.741,
        0.314, 0.717, 0.741,
        0.50, 0.5, 0
    ]
).astype(np.float32).reshape(-1, 3)
  173. def preproc(img, input_size, swap=(2, 0, 1)):
  174. if len(img.shape) == 3:
  175. padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
  176. else:
  177. padded_img = np.ones(input_size, dtype=np.uint8) * 114
  178. r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
  179. resized_img = cv2.resize(
  180. img,
  181. (int(img.shape[1] * r), int(img.shape[0] * r)),
  182. interpolation=cv2.INTER_LINEAR,
  183. ).astype(np.uint8)
  184. padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
  185. padded_img = padded_img.transpose(swap)
  186. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
  187. return padded_img, r
  188. def nms(boxes, scores, nms_thr):
  189. """Single class NMS implemented in Numpy."""
  190. x1 = boxes[:, 0]
  191. y1 = boxes[:, 1]
  192. x2 = boxes[:, 2]
  193. y2 = boxes[:, 3]
  194. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  195. order = scores.argsort()[::-1]
  196. keep = []
  197. while order.size > 0:
  198. i = order[0]
  199. keep.append(i)
  200. xx1 = np.maximum(x1[i], x1[order[1:]])
  201. yy1 = np.maximum(y1[i], y1[order[1:]])
  202. xx2 = np.minimum(x2[i], x2[order[1:]])
  203. yy2 = np.minimum(y2[i], y2[order[1:]])
  204. w = np.maximum(0.0, xx2 - xx1 + 1)
  205. h = np.maximum(0.0, yy2 - yy1 + 1)
  206. inter = w * h
  207. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  208. inds = np.where(ovr <= nms_thr)[0]
  209. order = order[inds + 1]
  210. return keep
  211. def demo_postprocess(outputs, img_size, p6=False):
  212. grids = []
  213. expanded_strides = []
  214. strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
  215. hsizes = [img_size[0] // stride for stride in strides]
  216. wsizes = [img_size[1] // stride for stride in strides]
  217. for hsize, wsize, stride in zip(hsizes, wsizes, strides):
  218. xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
  219. grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
  220. grids.append(grid)
  221. shape = grid.shape[:2]
  222. expanded_strides.append(np.full((*shape, 1), stride))
  223. grids = np.concatenate(grids, 1)
  224. expanded_strides = np.concatenate(expanded_strides, 1)
  225. outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
  226. outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
  227. return outputs
  228. def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
  229. """Multiclass NMS implemented in Numpy. Class-agnostic version."""
  230. cls_inds = scores.argmax(1)
  231. cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
  232. valid_score_mask = cls_scores > score_thr
  233. if valid_score_mask.sum() == 0:
  234. return None
  235. valid_scores = cls_scores[valid_score_mask]
  236. valid_boxes = boxes[valid_score_mask]
  237. valid_cls_inds = cls_inds[valid_score_mask]
  238. keep = nms(valid_boxes, valid_scores, nms_thr)
  239. if keep:
  240. dets = np.concatenate(
  241. [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
  242. )
  243. return dets
  244. def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
  245. """Multiclass NMS implemented in Numpy. Class-aware version."""
  246. final_dets = []
  247. num_classes = scores.shape[1]
  248. for cls_ind in range(num_classes):
  249. cls_scores = scores[:, cls_ind]
  250. valid_score_mask = cls_scores > score_thr
  251. if valid_score_mask.sum() == 0:
  252. continue
  253. else:
  254. valid_scores = cls_scores[valid_score_mask]
  255. valid_boxes = boxes[valid_score_mask]
  256. keep = nms(valid_boxes, valid_scores, nms_thr)
  257. if len(keep) > 0:
  258. cls_inds = np.ones((len(keep), 1)) * cls_ind
  259. dets = np.concatenate(
  260. [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
  261. )
  262. final_dets.append(dets)
  263. if len(final_dets) == 0:
  264. return None
  265. return np.concatenate(final_dets, 0)
  266. def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
  267. """Multiclass NMS implemented in Numpy"""
  268. if class_agnostic:
  269. nms_method = multiclass_nms_class_agnostic
  270. else:
  271. nms_method = multiclass_nms_class_aware
  272. return nms_method(boxes, scores, nms_thr, score_thr)
  273. def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
  274. for i in range(len(boxes)):
  275. box = boxes[i]
  276. cls_id = int(cls_ids[i])
  277. score = scores[i]
  278. if score < conf:
  279. continue
  280. x0 = int(box[0])
  281. y0 = int(box[1])
  282. x1 = int(box[2])
  283. y1 = int(box[3])
  284. color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
  285. text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
  286. txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
  287. font = cv2.FONT_HERSHEY_SIMPLEX
  288. txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
  289. cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
  290. txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
  291. cv2.rectangle(
  292. img,
  293. (x0, y0 + 1),
  294. (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
  295. txt_bk_color,
  296. -1
  297. )
  298. cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
  299. return img
  300. def load_watermark_info(watermark_txt, img_width, img_height):
  301. watermark_boxes = {}
  302. with open(watermark_txt, 'r') as f:
  303. for line in f.readlines():
  304. parts = line.strip().split()
  305. filename = parts[0]
  306. filename = os.path.basename(filename)
  307. x_center, y_center, w, h = map(float, parts[1:5])
  308. cls = int(float(parts[5])) # 转换类别为整数
  309. # 计算绝对坐标
  310. x1 = (x_center - w / 2) * img_width
  311. y1 = (y_center - h / 2) * img_height
  312. x2 = (x_center + w / 2) * img_width
  313. y2 = (y_center + h / 2) * img_height
  314. if filename not in watermark_boxes:
  315. watermark_boxes[filename] = []
  316. watermark_boxes[filename].append([x1, y1, x2, y2, cls])
  317. return watermark_boxes
  318. def compute_ciou(box1, box2):
  319. """计算CIoU,假设box格式为[x1, y1, x2, y2]"""
  320. x1, y1, x2, y2 = box1
  321. x1g, y1g, x2g, y2g = box2
  322. # 求交集面积
  323. xi1, yi1 = max(x1, x1g), max(y1, y1g)
  324. xi2, yi2 = min(x2, x2g), min(y2, y2g)
  325. inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
  326. # 求各自面积
  327. box_area = (x2 - x1) * (y2 - y1)
  328. boxg_area = (x2g - x1g) * (y2g - y1g)
  329. # 求并集面积
  330. union_area = box_area + boxg_area - inter_area
  331. # 求IoU
  332. iou = inter_area / union_area
  333. # 求CIoU额外项
  334. cw = max(x2, x2g) - min(x1, x1g)
  335. ch = max(y2, y2g) - min(y1, y1g)
  336. c2 = cw ** 2 + ch ** 2
  337. rho2 = ((x1 + x2 - x1g - x2g) ** 2 + (y1 + y2 - y1g - y2g) ** 2) / 4
  338. ciou = iou - (rho2 / c2)
  339. return ciou
  340. def detect_watermark(dets, watermark_boxes, threshold=0.5):
  341. for box, score, cls in zip(dets[:, :4], dets[:, 4], dets[:, 5]):
  342. for wm_box in watermark_boxes:
  343. wm_box_coords = wm_box[:4]
  344. wm_cls = wm_box[4]
  345. if cls == wm_cls:
  346. ciou = compute_ciou(box, wm_box_coords)
  347. if ciou > threshold:
  348. return True
  349. return False
  350. if __name__ == '__main__':
  351. test_img = "000000000030.jpg"
  352. model_file = "yolox_s.onnx"
  353. output_dir = "./output"
  354. watermark_txt = "./trigger/qrcode_positions.txt"
  355. input_shape = (640, 640)
  356. origin_img = cv2.imread(test_img)
  357. img, ratio = preproc(origin_img, input_shape)
  358. height, width, channels = origin_img.shape
  359. watermark_boxes = load_watermark_info(watermark_txt, width, height)
  360. session = onnxruntime.InferenceSession(model_file)
  361. ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
  362. output = session.run(None, ort_inputs)
  363. predictions = demo_postprocess(output[0], input_shape)[0]
  364. boxes = predictions[:, :4]
  365. scores = predictions[:, 4:5] * predictions[:, 5:]
  366. boxes_xyxy = np.ones_like(boxes)
  367. boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
  368. boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
  369. boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
  370. boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
  371. boxes_xyxy /= ratio
  372. dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
  373. # dets = np.vstack((dets, [2.9999999999999982, 234.0, 65.0, 296.0, 1.0, 0]))
  374. if dets is not None:
  375. final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
  376. origin_img = vis(origin_img, final_boxes, final_scores, final_cls_inds,
  377. conf=0.3, class_names=COCO_CLASSES)
  378. if detect_watermark(dets, watermark_boxes.get(test_img, [])):
  379. print("检测到黑盒水印")
  380. else:
  381. print("未检测到黑盒水印")
  382. os.makedirs(output_dir, exist_ok=True)
  383. output_path = os.path.join(output_dir, os.path.basename(test_img))
  384. cv2.imwrite(output_path, origin_img)