yolov5_pytorch_blackbox_process.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. """
  2. yolov5基于pytorch框架的黑盒水印处理验证流程
  3. """
  4. import os
  5. import cv2
  6. from watermark_verify.inference.yolov5_inference import YOLOV5Inference
  7. from watermark_verify.process.general_process_define import BlackBoxWatermarkProcessDefine
  8. from watermark_verify.tools import parse_qrcode_label_file
  9. from watermark_verify.tools.evaluate_tool import calculate_ciou
  10. class ModelWatermarkProcessor(BlackBoxWatermarkProcessDefine):
  11. def __init__(self, model_filename):
  12. super(ModelWatermarkProcessor, self).__init__(model_filename)
  13. def process(self) -> bool:
  14. """
  15. 根据流程定义进行处理,并返回模型标签验证结果
  16. :return: 模型标签验证结果
  17. """
  18. print(f"开始处理模型水印验证,模型文件: {self.model_filename}")
  19. # 获取权重文件,使用触发集进行模型推理, 将推理结果与触发集预先二维码保存位置进行比对,在误差范围内则进行下一步,否则返回False
  20. cls_image_mapping = parse_qrcode_label_file.parse_labels(self.qrcode_positions_file)
  21. accessed_cls = set()
  22. for cls, images in cls_image_mapping.items():
  23. for image in images:
  24. image_path = os.path.join(self.trigger_dir, image)
  25. detect_result = self.detect_secret_label(image_path, self.qrcode_positions_file, (640, 640))
  26. if detect_result:
  27. accessed_cls.add(cls)
  28. break
  29. if not accessed_cls == set(cls_image_mapping.keys()): # 所有的分类都检测出模型水印,模型水印检测结果为True
  30. return False
  31. verify_result = self.verify_label() # 模型标签检测通过,进行标签验证
  32. return verify_result
  33. def detect_secret_label(self, image_path, watermark_txt, input_shape) -> bool:
  34. """
  35. 对模型使用触发集进行检查,判断是否存在黑盒模型水印,如果对嵌入水印的图片样本正确率高于阈值,证明模型存在黑盒水印
  36. :param image_path: 输入图像路径
  37. :param watermark_txt: 水印标签文件路径
  38. :param input_shape: 模型输入图像大小,tuple
  39. :return: 检测结果
  40. """
  41. img = cv2.imread(image_path)
  42. height, width, channels = img.shape
  43. x_center, y_center, w, h, cls = parse_qrcode_label_file.load_watermark_info(watermark_txt, image_path)
  44. # 计算绝对坐标
  45. x1 = (x_center - w / 2) * width
  46. y1 = (y_center - h / 2) * height
  47. x2 = (x_center + w / 2) * width
  48. y2 = (y_center + h / 2) * height
  49. watermark_box = [x1, y1, x2, y2, cls]
  50. if len(watermark_box) == 0:
  51. return False
  52. dets = YOLOV5Inference(self.model_filename,input_size=input_shape).predict(image_path)
  53. if dets is not None:
  54. detect_result = detect_watermark(dets, watermark_box)
  55. return detect_result
  56. else:
  57. return False
  58. def detect_watermark(dets, watermark_box, threshold=0.5):
  59. if dets.size == 0: # 检查是否为空
  60. return False
  61. for box, score, cls in zip(dets[:, :4], dets[:, 4], dets[:, 5]):
  62. wm_box_coords = watermark_box[:4]
  63. wm_cls = watermark_box[4]
  64. if cls == wm_cls:
  65. ciou = calculate_ciou(box, wm_box_coords)
  66. if ciou > threshold:
  67. return True
  68. return False