|
@@ -1,11 +1,11 @@
|
|
|
+import os
|
|
|
+
|
|
|
import cv2
|
|
|
import tqdm
|
|
|
-# import wandb
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from torch import nn
|
|
|
from torchvision import transforms
|
|
|
-from watermark_codec import ModelEncoder
|
|
|
|
|
|
from block.dataset_get import CustomDataset
|
|
|
from block.val_get import val_get
|
|
@@ -22,7 +22,7 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
for module in model.modules():
|
|
|
if isinstance(module, nn.Conv2d):
|
|
|
conv_list.append(module)
|
|
|
- conv_list = conv_list[0:2]
|
|
|
+ conv_list = conv_list[1:3]
|
|
|
model_dict['enc_layers'] = conv_list # 将加密层保存至权重文件中
|
|
|
encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=args.key_path, device=args.device)
|
|
|
|
|
@@ -169,3 +169,82 @@ def train_embed(args, model_dict, loss, secret):
|
|
|
# })
|
|
|
# args.wandb_run.log(wandb_log)
|
|
|
torch.distributed.barrier() if args.distributed else None # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
|
|
|
+
|
|
|
+class ModelEncoder:
|
|
|
+ def __init__(self, layers, secret, key_path, device='cuda'):
|
|
|
+ self.device = device
|
|
|
+ self.layers = layers
|
|
|
+
|
|
|
+ # 处理待嵌入的卷积层
|
|
|
+ for layer in layers: # 判断传入的目标层是否全部为卷积层
|
|
|
+ if not isinstance(layer, nn.Conv2d):
|
|
|
+ raise TypeError('传入参数不是卷积层')
|
|
|
+ weights = [x.weight for x in layers]
|
|
|
+ weights = [weight.permute(2, 3, 1, 0) for weight in weights]
|
|
|
+ w = self.flatten_parameters(weights)
|
|
|
+ w_init = w.clone().detach()
|
|
|
+ print('Size of embedding parameters:', w.shape)
|
|
|
+
|
|
|
+ # 对密钥进行处理
|
|
|
+ self.secret = torch.tensor(self.string2bin(secret), dtype=torch.float).to(self.device) # the embedding code
|
|
|
+ self.secret_len = self.secret.shape[0]
|
|
|
+ print(f'Secret:{self.secret} secret length:{self.secret_len}')
|
|
|
+
|
|
|
+ # 生成随机的投影矩阵
|
|
|
+ self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
|
|
|
+ self.save_tensor(self.X_random, key_path) # 保存投影矩阵至指定位置
|
|
|
+
|
|
|
+ def get_embeder_loss(self):
|
|
|
+ """
|
|
|
+ 获取水印嵌入损失
|
|
|
+ :return: 水印嵌入的损失值
|
|
|
+ """
|
|
|
+ weights = [x.weight for x in self.layers]
|
|
|
+ weights = [weight.permute(2, 3, 1, 0) for weight in weights] # 使用pytorch框架时,要调整坐标顺序,保持与tensorflow版本一致
|
|
|
+ w = self.flatten_parameters(weights)
|
|
|
+ prob = self.get_prob(self.X_random, w)
|
|
|
+ penalty = self.loss_fun(prob, self.secret)
|
|
|
+ return penalty
|
|
|
+
|
|
|
+ def string2bin(self, s):
|
|
|
+ binary_representation = ''.join(format(ord(x), '08b') for x in s)
|
|
|
+ return [int(x) for x in binary_representation]
|
|
|
+
|
|
|
+ def save_tensor(self, tensor, save_path):
|
|
|
+ """
|
|
|
+ 保存张量至指定文件
|
|
|
+ :param tensor:待保存的张量
|
|
|
+ :param save_path: 保存位置,例如:/home/secret.pt
|
|
|
+ """
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
+ tensor = tensor.cpu()
|
|
|
+ numpy_array = tensor.numpy()
|
|
|
+ np.save(save_path, numpy_array)
|
|
|
+
|
|
|
+ def flatten_parameters(self, weights):
|
|
|
+ """
|
|
|
+ 处理传入的卷积层的权重参数
|
|
|
+ :param weights: 指定卷积层的权重列表
|
|
|
+ :return: 处理完成返回的张量
|
|
|
+ """
|
|
|
+ return torch.cat([torch.mean(x, dim=3).reshape(-1)
|
|
|
+ for x in weights])
|
|
|
+
|
|
|
+ def get_prob(self, x_random, w):
|
|
|
+ """
|
|
|
+ 获取投影矩阵与权重向量的计算结果
|
|
|
+ :param x_random: 投影矩阵
|
|
|
+ :param w: 权重向量
|
|
|
+ :return: 计算记过
|
|
|
+ """
|
|
|
+ mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
|
|
|
+ return mm.flatten()
|
|
|
+
|
|
|
+ def loss_fun(self, x, y):
|
|
|
+ """
|
|
|
+ 计算白盒水印嵌入时的损失
|
|
|
+ :param x: 预测值
|
|
|
+ :param y: 实际值
|
|
|
+ :return: 损失
|
|
|
+ """
|
|
|
+ return nn.BCEWithLogitsLoss()(x, y)
|