Преглед изворни кода

添加yolox白盒水印嵌入工程修改函数

liyan пре 10 месеци
родитељ
комит
487c9e732f

+ 9 - 3
watermark_generate/controller/watermark_generate_controller.py

@@ -11,7 +11,7 @@ from flask import Blueprint, request, jsonify
 from watermark_generate.exceptions import BusinessException
 from watermark_generate import logger
 from watermark_generate.tools import secret_label_func
-from watermark_generate.deals import yolox_pytorch_black_embed
+from watermark_generate.deals import yolox_pytorch_black_embed, yolox_pytorch_white_embed
 
 generator = Blueprint('generator', __name__)
 
@@ -45,6 +45,10 @@ def watermark_embed():
     model_file = data.get('model_file')
     model_value = data.get('model_value')
     model_type = data.get('model_type')
+    embed_type = data.get('embed_type')
+
+    if embed_type is None or embed_type == '':  # 通过传入参数控制嵌入方式,默认为黑盒水印嵌入
+        embed_type = 'blackbox'
 
     if model_file is None:
         raise BusinessException(message='模型代码路径不可为空', code=-1)
@@ -73,10 +77,12 @@ def watermark_embed():
     logger.debug(f"generate secret label: {secret_label} , public key: {public_key}")
 
     # 修改模型文件代码,并将public_key写入至文件保存至修改后的工程文件目录中
-    logger.info(f"modify model project source...")
+    logger.info(f"modify model project source, model_value: {model_value}, embed_type: {embed_type}")
     # TODO 添加其他模型工程代码处理
-    if model_value == 'yolox':
+    if model_value == 'yolox' and embed_type == 'blackbox':
         yolox_pytorch_black_embed.modify_model_project(secret_label, extract_to_path, public_key)
+    if model_value == 'yolox' and embed_type == 'whitebox':
+        yolox_pytorch_white_embed.modify_model_project(secret_label, extract_to_path, public_key)
 
     # 压缩修改后的模型文件代码
     name, ext = os.path.splitext(file_name)

+ 200 - 0
watermark_generate/deals/yolox_pytorch_white_embed.py

@@ -0,0 +1,200 @@
+import os
+
+from watermark_generate.tools import modify_file, general_tool
+from watermark_generate.exceptions import BusinessException
+
+
+def modify_model_project(secret_label: str, project_dir: str, public_key: str):
+    """
+    修改yolox工程代码
+    :param secret_label: 生成的密码标签
+    :param project_dir: 工程文件解压后的目录
+    :param public_key: 签名公钥,需保存至工程文件中
+    """
+
+    rela_project_path = general_tool.find_relative_directories(project_dir, 'YOLOX')
+    if not rela_project_path:
+        raise BusinessException(message="未找到指定模型的工程目录", code=-1)
+
+    project_dir = os.path.join(project_dir, rela_project_path[0])
+    project_file = os.path.join(project_dir, 'yolox/models/yolo_head.py')
+    project_file2 = os.path.join(project_dir, 'yolox/models/yolox.py')
+
+    if not os.path.exists(project_file) or not os.path.exists(project_file2):
+        raise BusinessException(message="指定待修改的工程文件未找到", code=-1)
+
+    # 把公钥保存至模型工程代码指定位置
+    keys_dir = os.path.join(project_dir, 'keys')
+    os.makedirs(keys_dir, exist_ok=True)
+    public_key_file = os.path.join(keys_dir, 'public.key')
+    # 写回文件
+    with open(public_key_file, 'w', encoding='utf-8') as file:
+        file.write(public_key)
+
+    # 查找替换代码块
+    old_source_block = \
+"""
+import torch.nn.functional as F
+"""
+    new_source_block = \
+"""
+import torch.nn.functional as F
+import os
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""     self.grids = [torch.zeros(1)] * len(in_channels)
+"""
+    new_source_block = \
+f"""
+        self.grids = [torch.zeros(1)] * len(in_channels)
+        self.init_model_embeder()
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""
+        reg_weight = 5.0
+        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
+
+        return (
+            loss,
+            reg_weight * loss_iou,
+            loss_obj,
+            loss_cls,
+            loss_l1,
+            num_fg / max(num_gts, 1),
+        )
+"""
+    new_source_block = \
+"""
+        reg_weight = 5.0
+        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
+        embed_loss = self.encoder.get_embeder_loss()
+        loss = loss + embed_loss
+
+        return (
+            loss,
+            reg_weight * loss_iou,
+            loss_obj,
+            loss_cls,
+            loss_l1,
+            num_fg / max(num_gts, 1),
+            embed_loss
+        )
+"""
+    # 文件替换
+    modify_file.replace_block_in_file(project_file, old_source_block, new_source_block)
+
+    # 查找替换代码块
+    old_source_block = \
+"""
+        if self.training:
+            assert targets is not None
+            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
+                fpn_outs, targets, x
+            )
+            outputs = {
+                "total_loss": loss,
+                "iou_loss": iou_loss,
+                "l1_loss": l1_loss,
+                "conf_loss": conf_loss,
+                "cls_loss": cls_loss,
+                "num_fg": num_fg,
+            }
+"""
+    new_source_block = \
+"""
+        if self.training:
+            assert targets is not None
+            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg, embed_loss = self.head(
+                fpn_outs, targets, x
+            )
+            outputs = {
+                "total_loss": loss,
+                "iou_loss": iou_loss,
+                "l1_loss": l1_loss,
+                "conf_loss": conf_loss,
+                "cls_loss": cls_loss,
+                "num_fg": num_fg,
+                "embed_loss": embed_loss
+            }
+"""
+    modify_file.replace_block_in_file(project_file2, old_source_block, new_source_block)
+
+    # 文件末尾追加代码块
+    append_source_block = f"""
+
+    def init_model_embeder(self):
+        secret_label = '{secret_label}'
+        conv_layers = []
+        for seq in self.cls_convs:
+            for base_conv in seq:
+                if isinstance(base_conv, BaseConv):
+                    conv_layers.append(base_conv.conv)
+        conv_layers = conv_layers[0:2]
+        self.encoder = ModelEncoder(layers=conv_layers, secret=secret_label, key_path='./YOLOX_outputs/key.pt', device='cuda')
+    """
+    # 向工程文件追加函数
+    modify_file.append_block_in_file(project_file, append_source_block)
+    # 文件末尾追加代码块
+    append_source_block = """
+
+class ModelEncoder:
+    def __init__(self, layers, secret, key_path, device='cuda'):
+        self.device = device
+        self.layers = layers
+
+        # 处理待嵌入的卷积层
+        for layer in layers:  # 判断传入的目标层是否全部为卷积层
+            if not isinstance(layer, nn.Conv2d):
+                raise TypeError('传入参数不是卷积层')
+        weights = [x.weight for x in layers]
+        w = self.flatten_parameters(weights)
+        w_init = w.clone().detach()
+        print('Size of embedding parameters:', w.shape)
+
+        # 对密钥进行处理
+        self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device)  # the embedding code
+        self.secret_len = self.secret.shape[0]
+        print(f'Secret:{self.secret} secret length:{self.secret_len}')
+
+        # 生成随机的投影矩阵
+        self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
+        self.save_tensor(self.X_random, key_path)  # 保存投影矩阵至指定位置
+
+    def get_embeder_loss(self):
+        weights = [x.weight for x in self.layers]
+        w = self.flatten_parameters(weights)
+        prob = self.get_prob(self.X_random, w)
+        penalty = self.loss_fun(prob, self.secret)
+        return penalty
+
+    def string2bin(self, s):
+        binary_representation = ''.join(format(ord(x), '08b') for x in s)
+        return [int(x) for x in binary_representation]
+
+    def save_tensor(self, tensor, save_path):
+        assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾"
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        torch.save(tensor, save_path)
+
+    def flatten_parameters(self, weights):
+        return torch.cat([torch.mean(x, dim=3).reshape(-1)
+                          for x in weights])
+
+    def get_prob(self, x_random, w):
+        mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
+        return F.sigmoid(mm).flatten()
+
+    def loss_fun(self, x, y):
+        return nn.BCEWithLogitsLoss()(x, y)
+
+    """
+    # 向工程文件追加函数
+    modify_file.append_block_in_file(project_file, append_source_block)