onnx_inference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. #!/usr/bin/env python3
  2. # Copyright (c) Megvii, Inc. and its affiliates.
  3. import os
  4. import cv2
  5. import numpy as np
  6. import onnxruntime
# COCO class names (80 entries), indexed by the integer class id the model
# emits; the __main__ script passes this tuple to vis() as ``class_names``,
# which looks names up as class_names[cls_id].
COCO_CLASSES = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)
# Per-class drawing palette: a flat list of colour components in [0, 1],
# reshaped to (-1, 3) so each row is one 3-component colour (80 rows, one per
# COCO class). vis() indexes it by class id and scales it by 255 for OpenCV.
_COLORS = np.array(
    [
        0.000, 0.447, 0.741,
        0.850, 0.325, 0.098,
        0.929, 0.694, 0.125,
        0.494, 0.184, 0.556,
        0.466, 0.674, 0.188,
        0.301, 0.745, 0.933,
        0.635, 0.078, 0.184,
        0.300, 0.300, 0.300,
        0.600, 0.600, 0.600,
        1.000, 0.000, 0.000,
        1.000, 0.500, 0.000,
        0.749, 0.749, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 1.000,
        0.667, 0.000, 1.000,
        0.333, 0.333, 0.000,
        0.333, 0.667, 0.000,
        0.333, 1.000, 0.000,
        0.667, 0.333, 0.000,
        0.667, 0.667, 0.000,
        0.667, 1.000, 0.000,
        1.000, 0.333, 0.000,
        1.000, 0.667, 0.000,
        1.000, 1.000, 0.000,
        0.000, 0.333, 0.500,
        0.000, 0.667, 0.500,
        0.000, 1.000, 0.500,
        0.333, 0.000, 0.500,
        0.333, 0.333, 0.500,
        0.333, 0.667, 0.500,
        0.333, 1.000, 0.500,
        0.667, 0.000, 0.500,
        0.667, 0.333, 0.500,
        0.667, 0.667, 0.500,
        0.667, 1.000, 0.500,
        1.000, 0.000, 0.500,
        1.000, 0.333, 0.500,
        1.000, 0.667, 0.500,
        1.000, 1.000, 0.500,
        0.000, 0.333, 1.000,
        0.000, 0.667, 1.000,
        0.000, 1.000, 1.000,
        0.333, 0.000, 1.000,
        0.333, 0.333, 1.000,
        0.333, 0.667, 1.000,
        0.333, 1.000, 1.000,
        0.667, 0.000, 1.000,
        0.667, 0.333, 1.000,
        0.667, 0.667, 1.000,
        0.667, 1.000, 1.000,
        1.000, 0.000, 1.000,
        1.000, 0.333, 1.000,
        1.000, 0.667, 1.000,
        0.333, 0.000, 0.000,
        0.500, 0.000, 0.000,
        0.667, 0.000, 0.000,
        0.833, 0.000, 0.000,
        1.000, 0.000, 0.000,
        0.000, 0.167, 0.000,
        0.000, 0.333, 0.000,
        0.000, 0.500, 0.000,
        0.000, 0.667, 0.000,
        0.000, 0.833, 0.000,
        0.000, 1.000, 0.000,
        0.000, 0.000, 0.167,
        0.000, 0.000, 0.333,
        0.000, 0.000, 0.500,
        0.000, 0.000, 0.667,
        0.000, 0.000, 0.833,
        0.000, 0.000, 1.000,
        0.000, 0.000, 0.000,
        0.143, 0.143, 0.143,
        0.286, 0.286, 0.286,
        0.429, 0.429, 0.429,
        0.571, 0.571, 0.571,
        0.714, 0.714, 0.714,
        0.857, 0.857, 0.857,
        0.000, 0.447, 0.741,
        0.314, 0.717, 0.741,
        0.50, 0.5, 0
    ]
).astype(np.float32).reshape(-1, 3)
  173. def preproc(img, input_size, swap=(2, 0, 1)):
  174. if len(img.shape) == 3:
  175. padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
  176. else:
  177. padded_img = np.ones(input_size, dtype=np.uint8) * 114
  178. r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
  179. resized_img = cv2.resize(
  180. img,
  181. (int(img.shape[1] * r), int(img.shape[0] * r)),
  182. interpolation=cv2.INTER_LINEAR,
  183. ).astype(np.uint8)
  184. padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
  185. padded_img = padded_img.transpose(swap)
  186. padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
  187. return padded_img, r
  188. def nms(boxes, scores, nms_thr):
  189. """Single class NMS implemented in Numpy."""
  190. x1 = boxes[:, 0]
  191. y1 = boxes[:, 1]
  192. x2 = boxes[:, 2]
  193. y2 = boxes[:, 3]
  194. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  195. order = scores.argsort()[::-1]
  196. keep = []
  197. while order.size > 0:
  198. i = order[0]
  199. keep.append(i)
  200. xx1 = np.maximum(x1[i], x1[order[1:]])
  201. yy1 = np.maximum(y1[i], y1[order[1:]])
  202. xx2 = np.minimum(x2[i], x2[order[1:]])
  203. yy2 = np.minimum(y2[i], y2[order[1:]])
  204. w = np.maximum(0.0, xx2 - xx1 + 1)
  205. h = np.maximum(0.0, yy2 - yy1 + 1)
  206. inter = w * h
  207. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  208. inds = np.where(ovr <= nms_thr)[0]
  209. order = order[inds + 1]
  210. return keep
  211. def demo_postprocess(outputs, img_size, p6=False):
  212. grids = []
  213. expanded_strides = []
  214. strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
  215. hsizes = [img_size[0] // stride for stride in strides]
  216. wsizes = [img_size[1] // stride for stride in strides]
  217. for hsize, wsize, stride in zip(hsizes, wsizes, strides):
  218. xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
  219. grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
  220. grids.append(grid)
  221. shape = grid.shape[:2]
  222. expanded_strides.append(np.full((*shape, 1), stride))
  223. grids = np.concatenate(grids, 1)
  224. expanded_strides = np.concatenate(expanded_strides, 1)
  225. outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
  226. outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
  227. return outputs
  228. def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
  229. """Multiclass NMS implemented in Numpy. Class-agnostic version."""
  230. cls_inds = scores.argmax(1)
  231. cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
  232. valid_score_mask = cls_scores > score_thr
  233. if valid_score_mask.sum() == 0:
  234. return None
  235. valid_scores = cls_scores[valid_score_mask]
  236. valid_boxes = boxes[valid_score_mask]
  237. valid_cls_inds = cls_inds[valid_score_mask]
  238. keep = nms(valid_boxes, valid_scores, nms_thr)
  239. if keep:
  240. dets = np.concatenate(
  241. [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
  242. )
  243. return dets
  244. def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
  245. """Multiclass NMS implemented in Numpy. Class-aware version."""
  246. final_dets = []
  247. num_classes = scores.shape[1]
  248. for cls_ind in range(num_classes):
  249. cls_scores = scores[:, cls_ind]
  250. valid_score_mask = cls_scores > score_thr
  251. if valid_score_mask.sum() == 0:
  252. continue
  253. else:
  254. valid_scores = cls_scores[valid_score_mask]
  255. valid_boxes = boxes[valid_score_mask]
  256. keep = nms(valid_boxes, valid_scores, nms_thr)
  257. if len(keep) > 0:
  258. cls_inds = np.ones((len(keep), 1)) * cls_ind
  259. dets = np.concatenate(
  260. [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
  261. )
  262. final_dets.append(dets)
  263. if len(final_dets) == 0:
  264. return None
  265. return np.concatenate(final_dets, 0)
  266. def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
  267. """Multiclass NMS implemented in Numpy"""
  268. if class_agnostic:
  269. nms_method = multiclass_nms_class_agnostic
  270. else:
  271. nms_method = multiclass_nms_class_aware
  272. return nms_method(boxes, scores, nms_thr, score_thr)
  273. def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
  274. for i in range(len(boxes)):
  275. box = boxes[i]
  276. cls_id = int(cls_ids[i])
  277. score = scores[i]
  278. if score < conf:
  279. continue
  280. x0 = int(box[0])
  281. y0 = int(box[1])
  282. x1 = int(box[2])
  283. y1 = int(box[3])
  284. color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
  285. text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
  286. txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
  287. font = cv2.FONT_HERSHEY_SIMPLEX
  288. txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
  289. cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
  290. txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
  291. cv2.rectangle(
  292. img,
  293. (x0, y0 + 1),
  294. (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
  295. txt_bk_color,
  296. -1
  297. )
  298. cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
  299. return img
  300. def load_watermark_info(watermark_txt, img_width, img_height, image_path):
  301. """
  302. 从标签文件中加载指定图片二维码嵌入坐标及所属类别
  303. :param watermark_txt: 标签文件
  304. :param img_width: 图像宽度
  305. :param img_height: 图像高度
  306. :param image_path: 图片路径
  307. :return: [x1, y1, x2, y2, cls]
  308. """
  309. with open(watermark_txt, 'r') as f:
  310. for line in f.readlines():
  311. parts = line.strip().split()
  312. filename = parts[0]
  313. filename = os.path.basename(filename)
  314. if filename == os.path.basename(image_path):
  315. x_center, y_center, w, h = map(float, parts[1:5])
  316. cls = int(float(parts[5])) # 转换类别为整数
  317. # 计算绝对坐标
  318. x1 = (x_center - w / 2) * img_width
  319. y1 = (y_center - h / 2) * img_height
  320. x2 = (x_center + w / 2) * img_width
  321. y2 = (y_center + h / 2) * img_height
  322. return [x1, y1, x2, y2, cls]
  323. return []
  324. def compute_ciou(box1, box2):
  325. """计算CIoU,假设box格式为[x1, y1, x2, y2]"""
  326. x1, y1, x2, y2 = box1
  327. x1g, y1g, x2g, y2g = box2
  328. # 求交集面积
  329. xi1, yi1 = max(x1, x1g), max(y1, y1g)
  330. xi2, yi2 = min(x2, x2g), min(y2, y2g)
  331. inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
  332. # 求各自面积
  333. box_area = (x2 - x1) * (y2 - y1)
  334. boxg_area = (x2g - x1g) * (y2g - y1g)
  335. # 求并集面积
  336. union_area = box_area + boxg_area - inter_area
  337. # 求IoU
  338. iou = inter_area / union_area
  339. # 求CIoU额外项
  340. cw = max(x2, x2g) - min(x1, x1g)
  341. ch = max(y2, y2g) - min(y1, y1g)
  342. c2 = cw ** 2 + ch ** 2
  343. rho2 = ((x1 + x2 - x1g - x2g) ** 2 + (y1 + y2 - y1g - y2g) ** 2) / 4
  344. ciou = iou - (rho2 / c2)
  345. return ciou
  346. def detect_watermark(dets, watermark_box, threshold=0.5):
  347. for box, score, cls in zip(dets[:, :4], dets[:, 4], dets[:, 5]):
  348. wm_box_coords = watermark_box[:4]
  349. wm_cls = watermark_box[4]
  350. if cls == wm_cls:
  351. ciou = compute_ciou(box, wm_box_coords)
  352. if ciou > threshold:
  353. return True
  354. return False
  355. if __name__ == '__main__':
  356. test_img = "000000000030.jpg"
  357. model_file = "yolox_s.onnx"
  358. output_dir = "./output"
  359. watermark_txt = "./trigger/qrcode_positions.txt"
  360. input_shape = (640, 640)
  361. origin_img = cv2.imread(test_img)
  362. img, ratio = preproc(origin_img, input_shape)
  363. height, width, channels = origin_img.shape
  364. watermark_box = load_watermark_info(watermark_txt, width, height, test_img)
  365. session = onnxruntime.InferenceSession(model_file)
  366. ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
  367. output = session.run(None, ort_inputs)
  368. predictions = demo_postprocess(output[0], input_shape)[0]
  369. boxes = predictions[:, :4]
  370. scores = predictions[:, 4:5] * predictions[:, 5:]
  371. boxes_xyxy = np.ones_like(boxes)
  372. boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
  373. boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
  374. boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
  375. boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
  376. boxes_xyxy /= ratio
  377. dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
  378. # dets = np.vstack((dets, [2.9999999999999982, 234.0, 65.0, 296.0, 1.0, 0]))
  379. if dets is not None:
  380. final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
  381. origin_img = vis(origin_img, final_boxes, final_scores, final_cls_inds,
  382. conf=0.3, class_names=COCO_CLASSES)
  383. if detect_watermark(dets, watermark_box):
  384. print("检测到黑盒水印")
  385. else:
  386. print("未检测到黑盒水印")
  387. os.makedirs(output_dir, exist_ok=True)
  388. output_path = os.path.join(output_dir, os.path.basename(test_img))
  389. cv2.imwrite(output_path, origin_img)