ssd.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import numpy as np
  2. import onnxruntime
  3. from PIL import Image
  4. from watermark_verify.inference.yolox import compute_ciou
  5. from watermark_verify.tools import parse_qrcode_label_file
  6. from watermark_verify.utils.anchors import get_anchors
  7. from watermark_verify.utils.utils_bbox import BBoxUtility
  8. # ---------------------------------------------------------#
  9. # 将图像转换成RGB图像,防止灰度图在预测时报错。
  10. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  11. # ---------------------------------------------------------#
  12. def cvtColor(image):
  13. if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
  14. return image
  15. else:
  16. image = image.convert('RGB')
  17. return image
  18. # ---------------------------------------------------#
  19. # 对输入图像进行resize
  20. # ---------------------------------------------------#
  21. def resize_image(image, size, letterbox_image):
  22. iw, ih = image.size
  23. w, h = size
  24. if letterbox_image:
  25. scale = min(w / iw, h / ih)
  26. nw = int(iw * scale)
  27. nh = int(ih * scale)
  28. image = image.resize((nw, nh), Image.BICUBIC)
  29. new_image = Image.new('RGB', size, (128, 128, 128))
  30. new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
  31. else:
  32. new_image = image.resize((w, h), Image.BICUBIC)
  33. return new_image
  34. # ---------------------------------------------------#
  35. # Preprocess the input: subtract the per-channel means
  36. # ---------------------------------------------------#
  37. def preprocess_input(inputs):
  38. MEANS = (104, 117, 123)
  39. return inputs - MEANS
  40. # ---------------------------------------------------#
  41. # 处理输入图像
  42. # ---------------------------------------------------#
  43. def deal_img(img_path, resized_size):
  44. image = Image.open(img_path)
  45. image_shape = np.array(np.shape(image)[0:2])
  46. # ---------------------------------------------------------#
  47. # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
  48. # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
  49. # ---------------------------------------------------------#
  50. image = cvtColor(image)
  51. image_data = resize_image(image, resized_size, False)
  52. image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
  53. image_data = image_data.astype('float32')
  54. return image_data, image_shape
  55. # ---------------------------------------------------#
  56. # 检测图像水印
  57. # ---------------------------------------------------#
  58. def detect_watermark(results, watermark_box, threshold=0.5):
  59. # 解析输出结果
  60. if len(results[0]) == 0:
  61. return False
  62. top_label = np.array(results[0][:, 4], dtype='int32')
  63. top_conf = results[0][:, 5]
  64. top_boxes = results[0][:, :4]
  65. for box, score, cls in zip(top_boxes, top_conf, top_label):
  66. wm_box_coords = watermark_box[:4]
  67. wm_cls = watermark_box[4]
  68. if cls == wm_cls:
  69. ciou = compute_ciou(box, wm_box_coords)
  70. if ciou > threshold:
  71. return True
  72. return False
  73. def predict_and_detect(image_path, model_file, watermark_txt, input_shape) -> bool:
  74. """
  75. 使用指定onnx文件进行预测并进行黑盒水印检测
  76. :param image_path: 输入图像路径
  77. :param model_file: 模型文件路径
  78. :param watermark_txt: 水印标签文件路径
  79. :param input_shape: 模型输入图像大小,tuple
  80. :return:
  81. """
  82. image_data, image_shape = deal_img(image_path, input_shape)
  83. # 解析标签嵌入位置
  84. parse_label = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  85. if len(parse_label) < 5:
  86. return False
  87. x_center, y_center, w, h, cls = parse_label
  88. # 计算绝对坐标
  89. height, width = image_shape
  90. x1 = (x_center - w / 2) * width
  91. y1 = (y_center - h / 2) * height
  92. x2 = (x_center + w / 2) * width
  93. y2 = (y_center + h / 2) * height
  94. watermark_box = [x1, y1, x2, y2, cls]
  95. if len(watermark_box) == 0:
  96. return False
  97. # 使用onnx进行推理
  98. session = onnxruntime.InferenceSession(model_file)
  99. ort_inputs = {session.get_inputs()[0].name: image_data}
  100. output = session.run(None, ort_inputs)
  101. # 处理模型预测输出
  102. num_classes = 20
  103. bbox_util = BBoxUtility(num_classes)
  104. anchors = get_anchors(input_shape)
  105. nms_iou = 0.45
  106. confidence = 0.5
  107. results = bbox_util.decode_box(output, anchors, image_shape, input_shape, False, nms_iou=nms_iou,
  108. confidence=confidence)
  109. if results is not None:
  110. detect_result = detect_watermark(results, watermark_box)
  111. return detect_result
  112. else:
  113. return False