|
@@ -0,0 +1,163 @@
|
|
|
+"""
|
|
|
+水印通用流程定义
|
|
|
+"""
|
|
|
+import os
|
|
|
+
|
|
|
+from watermark_verify import logger
|
|
|
+from watermark_verify.tools import secret_label_func, parse_qrcode_label_file
|
|
|
+from watermark_verify.tools.qrcode_tool import detect_and_decode_qr_code
|
|
|
+
|
|
|
+
|
|
|
+class WhiteBoxWatermarkProcessDefine:
|
|
|
+ """
|
|
|
+ 白盒水印通用处理流程定义
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, model_filename):
|
|
|
+ """
|
|
|
+ 检查必要参数,参数检查成功后,初始化参数
|
|
|
+ """
|
|
|
+ root_dir = os.path.dirname(model_filename)
|
|
|
+ logger.info(f"开始检测模型白盒水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
|
+ # 获取签名公钥信息,检查投影矩阵位置
|
|
|
+ public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
|
|
|
+ x_random_file = os.path.join(root_dir, 'keys', 'key.npy')
|
|
|
+ if not os.path.exists(x_random_file):
|
|
|
+ logger.error(f"x_random_file={x_random_file}, 投影矩阵保存文件不存在")
|
|
|
+ if not os.path.exists(public_key_txt):
|
|
|
+ logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
|
|
|
+ raise FileExistsError("签名公钥文件不存在")
|
|
|
+ with open(public_key_txt, 'r') as file:
|
|
|
+ public_key = file.read()
|
|
|
+ logger.debug(f"x_random_file={x_random_file}, public_key_txt={public_key_txt}, public_key={public_key}")
|
|
|
+ if not public_key or public_key == '':
|
|
|
+ logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
|
|
|
+ raise RuntimeError("获取的签名公钥信息为空")
|
|
|
+ self.model_filename = model_filename
|
|
|
+ self.x_random_file = x_random_file
|
|
|
+ self.public_key = public_key
|
|
|
+
|
|
|
+ def extract_label(self, start, end):
|
|
|
+ import onnx
|
|
|
+ import numpy as np
|
|
|
+ """
|
|
|
+ 标签提取
|
|
|
+ :return: 提取出的密码标签
|
|
|
+ """
|
|
|
+ model = onnx.load(self.model_filename) # 加载 ONNX 模型
|
|
|
+ graph = model.graph # 获取模型图(graph)
|
|
|
+ weights = []
|
|
|
+ # 遍历图中的节点
|
|
|
+ for node in graph.node:
|
|
|
+ if node.op_type == "Conv": # 查找嵌入白盒水印的卷积层节点,卷积层名字可解析onnx文件后查找得到
|
|
|
+ weight_name = node.input[1] # 通常第一个是输入x、第二个输入是权重w、第三个是偏置b
|
|
|
+ for initializer in graph.initializer:
|
|
|
+ if initializer.name == weight_name:
|
|
|
+ # 获取权重数据
|
|
|
+ weights.append(onnx.numpy_helper.to_array(initializer))
|
|
|
+ weights = weights[start:end]
|
|
|
+ weights = [np.transpose(weight, (2, 3, 1, 0)) for weight in
|
|
|
+ weights] # 将onnx文件的权重格式由(out_channels, in_channels, kernel_height, kernel_width),转换为(kernel_height, kernel_width, in_channels, out_channels)
|
|
|
+ x_random = np.load(self.x_random_file)
|
|
|
+ # 计算嵌入的白盒水印
|
|
|
+ w = np.concatenate(
|
|
|
+ [np.mean(x, axis=3).reshape(-1) for x in weights]) # 处理传入的卷积层的权重参数,对卷积核进行按out_channels维度取平均,拉直
|
|
|
+ mm = np.dot(x_random, w.reshape((w.shape[0], 1))) # 进行矩阵乘法
|
|
|
+ sigmoid_mm = 1 / (1 + np.exp(-mm)) # 计算 Sigmoid 函数
|
|
|
+ prob = sigmoid_mm.flatten() # 拉直运算结果
|
|
|
+ decode = np.where(prob > 0.5, 1, 0) # 获取最终字节序列
|
|
|
+ code_string = ''.join([str(x) for x in decode.tolist()]) # 转换为字节序列字符串,类似"0100010011111"
|
|
|
+ # 将字节序列字符串转换为字符串
|
|
|
+ secret_label = ''.join(chr(int(code_string[i:i + 8], 2)) for i in range(0, len(code_string), 8))
|
|
|
+ return secret_label
|
|
|
+
|
|
|
+ def verify_label(self, start=0, end=3) -> bool:
|
|
|
+ """
|
|
|
+ 标签验证
|
|
|
+ :param start: 嵌入标签开始卷积层位置,包括起始位置
|
|
|
+ :param end: 嵌入标签结束卷积层位置,不包括结束位置
|
|
|
+ :return: 标签验证结果
|
|
|
+ """
|
|
|
+ secret_label = self.extract_label(start, end)
|
|
|
+ label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label,
|
|
|
+ public_key=self.public_key)
|
|
|
+ return label_check_result
|
|
|
+
|
|
|
+
|
|
|
+class BlackBoxWatermarkProcessDefine:
|
|
|
+ """
|
|
|
+ 黑盒水印通用处理流程定义
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, model_filename):
|
|
|
+ """
|
|
|
+ 检查必要参数,参数检查成功,返回所需验证参数
|
|
|
+ :return: 验证所需参数元组
|
|
|
+ """
|
|
|
+ root_dir = os.path.dirname(model_filename)
|
|
|
+ logger.info(f"开始检测模型水印, model_filename: {model_filename}, root_dir: {root_dir}")
|
|
|
+ # 获取触发集目录,公钥信息
|
|
|
+ trigger_dir = os.path.join(root_dir, 'trigger')
|
|
|
+ public_key_txt = os.path.join(root_dir, 'keys', 'public.key')
|
|
|
+ if not os.path.exists(trigger_dir):
|
|
|
+ logger.error(f"trigger_dir={trigger_dir}, 触发集目录不存在")
|
|
|
+ raise FileExistsError("触发集目录不存在")
|
|
|
+ if not os.path.exists(public_key_txt):
|
|
|
+ logger.error(f"public_key_txt={public_key_txt}, 签名公钥文件不存在")
|
|
|
+ raise FileExistsError("签名公钥文件不存在")
|
|
|
+ with open(public_key_txt, 'r') as file:
|
|
|
+ public_key = file.read()
|
|
|
+ logger.debug(f"trigger_dir={trigger_dir}, public_key_txt={public_key_txt}, public_key={public_key}")
|
|
|
+ if not public_key or public_key == '':
|
|
|
+ logger.error(f"获取的签名公钥信息为空, public_key={public_key}")
|
|
|
+ raise RuntimeError("获取的签名公钥信息为空")
|
|
|
+ qrcode_positions_file = os.path.join(trigger_dir, 'qrcode_positions.txt')
|
|
|
+ if not os.path.exists(qrcode_positions_file):
|
|
|
+ raise FileNotFoundError("二维码标签文件不存在")
|
|
|
+ self.model_filename = model_filename
|
|
|
+ self.trigger_dir = trigger_dir
|
|
|
+ self.public_key = public_key
|
|
|
+
|
|
|
+ def extract_label(self):
|
|
|
+ """
|
|
|
+ 从触发集中提取密码标签
|
|
|
+ :return: 密码标签
|
|
|
+ """
|
|
|
+ # Initialize variables to store the paths
|
|
|
+ image_folder_path = None
|
|
|
+ qrcode_positions_file_path = None
|
|
|
+ label = ''
|
|
|
+
|
|
|
+ # Walk through the extracted folder to find the specific folder and file
|
|
|
+ for root, dirs, files in os.walk(self.trigger_dir):
|
|
|
+ if 'images' in dirs:
|
|
|
+ image_folder_path = os.path.join(root, 'images')
|
|
|
+ if 'qrcode_positions.txt' in files:
|
|
|
+ qrcode_positions_file_path = os.path.join(root, 'qrcode_positions.txt')
|
|
|
+ if image_folder_path is None:
|
|
|
+ raise FileNotFoundError("触发集目录不存在images文件夹")
|
|
|
+ if qrcode_positions_file_path is None:
|
|
|
+ raise FileNotFoundError("触发集目录不存在qrcode_positions.txt")
|
|
|
+
|
|
|
+ sub_image_dir_names = os.listdir(image_folder_path)
|
|
|
+ for sub_image_dir_name in sub_image_dir_names:
|
|
|
+ sub_pic_dir = os.path.join(image_folder_path, sub_image_dir_name)
|
|
|
+ images = os.listdir(sub_pic_dir)
|
|
|
+ for image in images:
|
|
|
+ img_path = os.path.join(sub_pic_dir, image)
|
|
|
+ watermark_box = parse_qrcode_label_file.load_watermark_info(qrcode_positions_file_path, img_path)
|
|
|
+ label_part, _ = detect_and_decode_qr_code(img_path, watermark_box)
|
|
|
+ if label_part is not None:
|
|
|
+ label = label + label_part
|
|
|
+ break
|
|
|
+ return label
|
|
|
+
|
|
|
+ def verify_label(self) -> bool:
|
|
|
+ """
|
|
|
+ 标签验证
|
|
|
+ :return: 标签验证结果
|
|
|
+ """
|
|
|
+ secret_label = self.extract_label()
|
|
|
+ label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label,
|
|
|
+ public_key=self.public_key)
|
|
|
+ return label_check_result
|