classfication_model_watermark_detect_test.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. import onnxruntime as ort
  3. from torchvision import transforms
  4. from PIL import Image
  5. import numpy as np
  6. # 模型加载
  7. onnx_model_path = 'your_model_path.onnx'
  8. session = ort.InferenceSession(onnx_model_path)
  9. # 图像预处理
  10. preprocess = transforms.Compose([
  11. transforms.Resize((224, 224)), # 根据你的模型输入大小调整
  12. transforms.ToTensor(),
  13. ])
  14. # 触发集目录和嵌入位置的TXT文件路径
  15. trigger_dir = 'path_to_trigger_images'
  16. location_file = 'path_to_location_txt.txt'
  17. # 读取嵌入位置的TXT文件
  18. embedding_positions = {}
  19. with open(location_file, 'r') as file:
  20. for line in file:
  21. # 假设TXT文件中每行的格式是: 文件名 x y width height
  22. filename, x, y, width, height = line.strip().split()
  23. embedding_positions[filename] = (int(x), int(y), int(width), int(height))
  24. # 扫描触发集目录并处理每张图像
  25. watermark_success_rates = []
  26. for img_name in os.listdir(trigger_dir):
  27. if img_name in embedding_positions:
  28. # 加载图像
  29. img_path = os.path.join(trigger_dir, img_name)
  30. image = Image.open(img_path).convert('RGB')
  31. # 获取嵌入位置
  32. x, y, width, height = embedding_positions[img_name]
  33. # 裁剪出嵌入二维码的区域
  34. cropped_image = image.crop((x, y, x + width, y + height))
  35. # 图像预处理
  36. input_tensor = preprocess(cropped_image).unsqueeze(0).numpy() # 增加batch维度并转换为numpy
  37. # 获取ONNX模型的输入名称
  38. input_name = session.get_inputs()[0].name
  39. # 模型推理
  40. output = session.run(None, {input_name: input_tensor})
  41. # 假设你有一个期望的标签,比如`expected_label`
  42. predicted_label = np.argmax(output[0])
  43. expected_label = 1 # 根据实际情况设置
  44. # 判断预测是否正确
  45. is_correct = (predicted_label == expected_label)
  46. watermark_success_rates.append(is_correct)
  47. # 计算整体水印成功率
  48. overall_success_rate = np.mean(watermark_success_rates)
  49. # 输出结果
  50. threshold = 0.9 # 设置成功率阈值
  51. if overall_success_rate > threshold:
  52. print(f"模型中嵌入了水印,水印成功率: {overall_success_rate * 100:.2f}%")
  53. else:
  54. print("模型中未嵌入水印。")