Переглянути джерело

修改水印嵌入训练代码

liyan 10 місяців тому
батько
коміт
9d454c1cbd
2 змінених файлів з 85 додано та 3 видалено
  1. 3 0
      block/dataset_get.py
  2. 82 3
      block/train_with_watermark.py

+ 3 - 0
block/dataset_get.py

@@ -81,6 +81,9 @@ class CustomDataset(torch.utils.data.Dataset):
 #         return len(self.image_paths)
 #
 #     def __getitem__(self, idx):
+#         # step 1 遍历每个类别的图片,每个水印都需要在所有类别选择5%的图片进行水印嵌入,并修改其标签为水印索引
+#         # step 2 编写函数,判断指定index是否需要处理
+#         # step 3 指定图片嵌入二维码,并修改标签为嵌入水印索引
 #         image_path = self.image_paths[idx]
 #         label = self.labels[idx]
 #         # 使用PIL加载图像并调整大小

+ 82 - 3
block/train_with_watermark.py

@@ -1,11 +1,11 @@
+import os
+
 import cv2
 import tqdm
-# import wandb
 import torch
 import numpy as np
 from torch import nn
 from torchvision import transforms
-from watermark_codec import ModelEncoder
 
 from block.dataset_get import CustomDataset
 from block.val_get import val_get
@@ -22,7 +22,7 @@ def train_embed(args, model_dict, loss, secret):
     for module in model.modules():
         if isinstance(module, nn.Conv2d):
             conv_list.append(module)
-    conv_list = conv_list[0:2]
+    conv_list = conv_list[1:3]
     model_dict['enc_layers'] = conv_list  # 将加密层保存至权重文件中
     encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
 
@@ -169,3 +169,82 @@ def train_embed(args, model_dict, loss, secret):
             #                       })
             #     args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
+
class ModelEncoder:
    """White-box watermark encoder for convolutional layers.

    A secret string is converted to a bit vector and a random projection
    matrix is drawn once at construction time.  During training,
    :meth:`get_embeder_loss` returns a penalty that pushes the projection of
    the (channel-averaged, flattened) convolution weights towards the secret
    bits.  The projection matrix is written to ``key_path`` when the encoder
    is constructed.
    """

    def __init__(self, layers, secret: str, key_path: str, device='cuda'):
        """Prepare the secret bits and the projection matrix.

        :param layers: list of ``nn.Conv2d`` modules whose weights carry the watermark
        :param secret: string to embed; each character contributes its code point in binary
        :param key_path: file path the projection matrix is saved to (see :meth:`save_tensor`)
        :param device: device the secret and the projection matrix are moved to
        :raises TypeError: if any element of ``layers`` is not an ``nn.Conv2d``
        :raises ValueError: if ``secret`` is empty (the loss would be NaN)
        """
        self.device = device
        self.layers = layers

        # Every target layer must be a convolution layer.
        for layer in layers:
            if not isinstance(layer, nn.Conv2d):
                raise TypeError('传入参数不是卷积层')
        w = self._embedding_vector()
        print('Size of embedding parameters:', w.shape)

        # Turn the secret string into the bit vector that is embedded.
        self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device)
        self.secret_len = self.secret.shape[0]
        if self.secret_len == 0:
            # An empty target makes BCE average over zero elements (NaN loss).
            raise ValueError('secret must contain at least one character')
        print(f'Secret:{self.secret} secret length:{self.secret_len}')

        # Random projection matrix: one row per secret bit, one column per
        # embedding parameter.  Only the length of w is needed here, so no
        # copy of the weights is made.
        self.X_random = torch.randn((self.secret_len, w.shape[0])).to(self.device)
        self.save_tensor(self.X_random, key_path)  # persist the key

    def _embedding_vector(self):
        """Return the 1-D vector of embedding parameters built from ``self.layers``."""
        # Reorder the axes of each weight with permute(2, 3, 1, 0) to stay
        # consistent with the TensorFlow version of the algorithm.
        weights = [layer.weight.permute(2, 3, 1, 0) for layer in self.layers]
        return self.flatten_parameters(weights)

    def get_embeder_loss(self):
        """
        Compute the watermark embedding loss for the current layer weights.

        :return: scalar tensor, the BCE-with-logits loss between the projected
                 weights and the secret bits
        """
        w = self._embedding_vector()
        prob = self.get_prob(self.X_random, w)
        return self.loss_fun(prob, self.secret)

    def string2bin(self, s: str) -> list:
        """
        Convert a string to a flat list of bits.

        Each character is rendered as its code point in binary, zero-padded to
        at least 8 digits; characters with a code point above 255 therefore
        yield more than 8 bits.

        :param s: input string
        :return: list of ints, each 0 or 1
        """
        binary_representation = ''.join(format(ord(x), '08b') for x in s)
        return [int(x) for x in binary_representation]

    def save_tensor(self, tensor, save_path: str) -> None:
        """
        Save a tensor to disk as a NumPy array.

        ``np.save`` appends ``.npy`` when ``save_path`` does not already end
        with it, so ``/home/secret.pt`` is stored as ``/home/secret.pt.npy``.

        :param tensor: tensor to save
        :param save_path: destination path, e.g. ``/home/secret.npy``; a bare
                          file name saves into the current working directory
        """
        directory = os.path.dirname(save_path)
        if directory:  # os.makedirs('') raises, so skip it for bare file names
            os.makedirs(directory, exist_ok=True)
        # detach() so tensors that require grad can be converted as well.
        numpy_array = tensor.detach().cpu().numpy()
        np.save(save_path, numpy_array)

    def flatten_parameters(self, weights):
        """
        Collapse a list of weight tensors into a single 1-D tensor.

        Each tensor is averaged over dimension 3, flattened, and the results
        are concatenated in list order.

        :param weights: list of 4-D weight tensors
        :return: 1-D tensor of embedding parameters
        """
        return torch.cat([torch.mean(x, dim=3).reshape(-1)
                          for x in weights])

    def get_prob(self, x_random, w):
        """
        Project the weight vector with the projection matrix.

        :param x_random: projection matrix of shape (secret_len, len(w))
        :param w: 1-D weight vector
        :return: 1-D tensor with one logit per secret bit
        """
        mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
        return mm.flatten()

    def loss_fun(self, x, y):
        """
        Loss used for white-box watermark embedding.

        :param x: predicted logits
        :param y: target bits
        :return: BCE-with-logits loss
        """
        return nn.BCEWithLogitsLoss()(x, y)