general_process_define.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """
  2. 水印通用流程定义
  3. """
  4. import os
  5. from watermark_verify import logger
  6. from watermark_verify.tools import secret_label_func, parse_qrcode_label_file
  7. from watermark_verify.tools.qrcode_tool import detect_and_decode_qr_code
  8. class WhiteBoxWatermarkProcessDefine:
  9. """
  10. 白盒水印通用处理流程定义
  11. """
  12. def __init__(self, model_filename):
  13. """
  14. 检查必要参数,参数检查成功后,初始化参数
  15. """
  16. root_dir = os.path.dirname(model_filename)
  17. logger.info(f"开始检测模型白盒水印, model_filename: {model_filename}, root_dir: {root_dir}")
  18. # 获取签名公钥信息,检查投影矩阵位置
  19. public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
  20. x_random_file = os.path.join(root_dir, 'keys', 'key.npy')
  21. if not os.path.exists(x_random_file):
  22. logger.error(f"x_random_file={x_random_file}, 投影矩阵保存文件不存在")
  23. if not os.path.exists(public_key_txt):
  24. logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
  25. raise FileExistsError("签名公钥文件不存在")
  26. with open(public_key_txt, 'r') as file:
  27. public_key = file.read()
  28. logger.debug(f"x_random_file={x_random_file}, public_key_txt={public_key_txt}, public_key={public_key}")
  29. if not public_key or public_key == '':
  30. logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
  31. raise RuntimeError("获取的签名公钥信息为空")
  32. self.model_filename = model_filename
  33. self.x_random_file = x_random_file
  34. self.public_key = public_key
  35. def extract_label(self, scope, indices):
  36. import onnx
  37. import numpy as np
  38. """
  39. 标签提取
  40. :return: 提取出的密码标签
  41. """
  42. model = onnx.load(self.model_filename) # 加载 ONNX 模型
  43. graph = model.graph # 获取模型图(graph)
  44. weights = []
  45. # 遍历图中的节点
  46. for node in graph.node:
  47. if node.op_type == "Conv": # 查找嵌入白盒水印的卷积层节点,卷积层名字可解析onnx文件后查找得到
  48. weight_name = node.input[1] # 通常第一个是输入x、第二个输入是权重w、第三个是偏置b
  49. for initializer in graph.initializer:
  50. if initializer.name == weight_name:
  51. # 获取权重数据
  52. weights.append(onnx.numpy_helper.to_array(initializer))
  53. if indices:
  54. weights = [weights[i] for i in indices if i < len(weights)]
  55. else:
  56. start, end = scope
  57. weights = weights[start:end]
  58. weights = [np.transpose(weight, (2, 3, 1, 0)) for weight in
  59. weights] # 将onnx文件的权重格式由(out_channels, in_channels, kernel_height, kernel_width),转换为(kernel_height, kernel_width, in_channels, out_channels)
  60. x_random = np.load(self.x_random_file)
  61. # 计算嵌入的白盒水印
  62. w = np.concatenate(
  63. [np.mean(x, axis=3).reshape(-1) for x in weights]) # 处理传入的卷积层的权重参数,对卷积核进行按out_channels维度取平均,拉直
  64. mm = np.dot(x_random, w.reshape((w.shape[0], 1))) # 进行矩阵乘法
  65. sigmoid_mm = 1 / (1 + np.exp(-mm)) # 计算 Sigmoid 函数
  66. prob = sigmoid_mm.flatten() # 拉直运算结果
  67. decode = np.where(prob > 0.5, 1, 0) # 获取最终字节序列
  68. code_string = ''.join([str(x) for x in decode.tolist()]) # 转换为字节序列字符串,类似"0100010011111"
  69. # 将字节序列字符串转换为字符串
  70. secret_label = ''.join(chr(int(code_string[i:i + 8], 2)) for i in range(0, len(code_string), 8))
  71. return secret_label
  72. def verify_label(self, scope=(0, 3), indices=None) -> bool:
  73. """
  74. 标签验证
  75. :param scope:嵌入标签卷积层位置区间,默认值(0,3),包含开始位置,不包含结束位置
  76. :param indices: 如果指定该参数,会从卷积层指定索引列表中进行权重获取,scope参数无效
  77. :return: 标签验证结果
  78. """
  79. start, end = scope
  80. secret_label = self.extract_label(scope, indices)
  81. label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label,
  82. public_key=self.public_key)
  83. return label_check_result
  84. class BlackBoxWatermarkProcessDefine:
  85. """
  86. 黑盒水印通用处理流程定义
  87. """
  88. def __init__(self, model_filename):
  89. """
  90. 检查必要参数,参数检查成功,返回所需验证参数
  91. :return: 验证所需参数元组
  92. """
  93. root_dir = os.path.dirname(model_filename)
  94. logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
  95. # 获取触发集目录,公钥信息
  96. trigger_dir = os.path.join(root_dir, 'trigger')
  97. public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
  98. if not os.path.exists(trigger_dir):
  99. logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
  100. raise FileExistsError("触发集目录不存在")
  101. if not os.path.exists(public_key_txt):
  102. logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
  103. raise FileExistsError("签名公钥文件不存在")
  104. with open(public_key_txt, 'r') as file:
  105. public_key = file.read()
  106. logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
  107. if not public_key or public_key == '':
  108. logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
  109. raise RuntimeError("获取的签名公钥信息为空")
  110. qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
  111. if not os.path.exists(qrcode_positions_file):
  112. raise FileNotFoundError("二维码标签文件不存在")
  113. self.model_filename = model_filename
  114. self.trigger_dir = trigger_dir
  115. self.public_key = public_key
  116. self.qrcode_positions_file = qrcode_positions_file
  117. def extract_label(self):
  118. """
  119. 从触发集中提取密码标签
  120. :return: 密码标签
  121. """
  122. # Initialize variables to store the paths
  123. image_folder_path = None
  124. qrcode_positions_file_path = None
  125. label = ''
  126. # Walk through the extracted folder to find the specific folder and file
  127. for root, dirs, files in os.walk(self.trigger_dir):
  128. if 'images' in dirs:
  129. image_folder_path = os.path.join(root, 'images')
  130. if 'qrcode_positions.txt' in files:
  131. qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
  132. if image_folder_path is None:
  133. raise FileNotFoundError("触发集目录不存在images文件夹")
  134. if qrcode_positions_file_path is None:
  135. raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
  136. sub_image_dir_names = os.listdir(image_folder_path)
  137. for sub_image_dir_name in sub_image_dir_names:
  138. sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
  139. images = os.listdir(sub_pic_dir)
  140. for image in images:
  141. img_path = os.path.join(sub_pic_dir, image)
  142. watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
  143. label_part, _ = detect_and_decode_qr_code(img_path, watermark_box)
  144. if label_part is not None:
  145. label = label + label_part
  146. break
  147. return label
  148. def verify_label(self) -> bool:
  149. """
  150. 标签验证
  151. :return: 标签验证结果
  152. """
  153. secret_label = self.extract_label()
  154. label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label,
  155. public_key=self.public_key)
  156. return label_check_result