model_watermark_detect_test.py

import os
import onnxruntime as ort
import numpy as np
from PIL import Image
import cv2

# Load the ONNX model
onnx_model_path = 'your_yolo_model.onnx'
session = ort.InferenceSession(onnx_model_path)

# YOLO model input size (adjust to match the exported model)
input_shape = (640, 640)

# Trigger-set directory, TXT file with embedding locations, and the class the
# watermark trigger is expected to be detected as (adjust to your model)
trigger_dir = 'path_to_trigger_images'
location_file = 'path_to_location_txt.txt'
expected_cls = 0

# Read the embedding locations from the TXT file
embedding_positions = {}
with open(location_file, 'r') as file:
    for line in file:
        # Each line of the TXT file is assumed to be: filename x y width height
        filename, x, y, width, height = line.strip().split()
        embedding_positions[filename] = (int(x), int(y), int(width), int(height))
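
# Example of the assumed location-file format (pixel coordinates in the
# original image, one trigger image per line; the file name is hypothetical):
#   trigger_001.jpg 120 80 64 64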

# Pre-processing for the YOLO model
def preprocess_image(image, input_shape):
    # Resize the image straight to the model input size (no letterbox padding)
    image = image.resize(input_shape)
    image_data = np.array(image).astype('float32')
    image_data /= 255.0  # Normalize to [0, 1]
    image_data = np.transpose(image_data, (2, 0, 1))  # HWC to CHW
    image_data = np.expand_dims(image_data, axis=0)  # Add batch dimension
    return image_data
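
# Rough usage sketch (hypothetical file name): preprocess_image turns any RGB
# image into a (1, 3, 640, 640) float32 tensor in [0, 1] matching input_shape.
# input_tensor = preprocess_image(Image.open('trigger_001.jpg').convert('RGB'), input_shape)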

# Post-processing for the YOLO model (original_shape is (height, width))
def postprocess_output(output, input_shape, original_shape, conf_threshold=0.5, iou_threshold=0.4):
    boxes, boxes_xywh, scores, classes = [], [], [], []
    for detection in output:
        x_center, y_center, width, height, confidence, *probs = detection
        if confidence < conf_threshold:
            continue
        # Scale from model input coordinates back to the original image
        x_min = int((x_center - width / 2) * original_shape[1] / input_shape[0])
        y_min = int((y_center - height / 2) * original_shape[0] / input_shape[1])
        x_max = int((x_center + width / 2) * original_shape[1] / input_shape[0])
        y_max = int((y_center + height / 2) * original_shape[0] / input_shape[1])
        class_id = int(np.argmax(probs))
        score = float(probs[class_id] * confidence)
        if score > conf_threshold:
            boxes.append([x_min, y_min, x_max, y_max])
            boxes_xywh.append([x_min, y_min, x_max - x_min, y_max - y_min])  # NMSBoxes expects x, y, w, h
            scores.append(score)
            classes.append(class_id)
    if not boxes:
        return [], [], []
    # Apply non-max suppression to filter overlapping boxes
    indices = cv2.dnn.NMSBoxes(boxes_xywh, scores, conf_threshold, iou_threshold)
    final_boxes, final_scores, final_classes = [], [], []
    for i in np.array(indices).flatten():
        final_boxes.append(boxes[i])
        final_scores.append(scores[i])
        final_classes.append(classes[i])
    return final_boxes, final_scores, final_classes
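
# Worked example with made-up numbers: with input_shape = (640, 640) and an
# original 1280x720 image (original_shape = (720, 1280)), a raw detection row
#   [320, 320, 64, 64, 0.9, 0.1, 0.8, 0.1]
# maps to x_min = (320 - 32) * 1280 / 640 = 576, y_min = (320 - 32) * 720 / 640 = 324,
# class_id = 1 and score = 0.8 * 0.9 = 0.72. This assumes the exported model
# emits raw (N, 5 + num_classes) rows without built-in NMS.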

# Scan the trigger-set directory and process each image
watermark_success_rates = []
for img_name in os.listdir(trigger_dir):
    if img_name in embedding_positions:
        # Load the image
        img_path = os.path.join(trigger_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        original_shape = image.size[::-1]  # PIL size is (width, height); keep (height, width)
        # Pre-process the image
        input_tensor = preprocess_image(image, input_shape)
        # Get the ONNX model's input name
        input_name = session.get_inputs()[0].name
        # Run inference
        output = session.run(None, {input_name: input_tensor})
        output = np.squeeze(output[0])
        # Post-process the output
        boxes, scores, classes = postprocess_output(output, input_shape, original_shape)
        # Look up the embedding location
        x, y, width, height = embedding_positions[img_name]
        region = (x, y, x + width, y + height)
        # Check whether the expected target is detected inside the embedded region
        found = False
        for box, class_id in zip(boxes, classes):
            if class_id == expected_cls:
                x_min, y_min, x_max, y_max = box
                detected_region = (x_min, y_min, x_max, y_max)
                # Require the detected box to lie entirely inside the embedded region
                if (detected_region[0] >= region[0] and detected_region[2] <= region[2] and
                        detected_region[1] >= region[1] and detected_region[3] <= region[3]):
                    found = True
                    break
        watermark_success_rates.append(found)
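
# The containment test above only counts a hit when the detection lies fully
# inside the embedded region. If partial overlap should also count, an
# IoU-based check along these lines could be used instead (unused sketch):
# def box_iou(a, b):
#     ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
#     ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
#     inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
#     union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
#     return inter / union if union > 0 else 0.0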

# Compute the overall watermark success rate (0.0 if no trigger images were evaluated)
overall_success_rate = np.mean(watermark_success_rates) if watermark_success_rates else 0.0

# Report the result
threshold = 0.9  # Success-rate threshold
if overall_success_rate > threshold:
    print(f"Watermark detected in the model, success rate: {overall_success_rate * 100:.2f}%")
else:
    print("No watermark detected in the model.")
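
# Optional extra reporting (illustrative): the raw hit count makes small
# trigger sets easier to interpret than the percentage alone.
print(f"Trigger images evaluated: {len(watermark_success_rates)}, "
      f"hits: {int(np.sum(watermark_success_rates))}")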