import torch.nn as nn
import torch
from torch.optim import SGD, Adam
import torch.nn.functional as F


def string2bin(s):
    """Encode *s* as a flat list of bits, exactly 8 bits (MSB first) per character.

    Each character's code point is written as 8 bits so that ``bin2string`` can
    split the bit string back into characters.

    :param s: text to encode; every character must have a code point below 256
    :return: list of ``0``/``1`` ints of length ``8 * len(s)``
    :raises ValueError: if a character's code point does not fit in 8 bits
    """
    bits = []
    for ch in s:
        code_point = ord(ch)
        # format(..., '08b') pads to *at least* 8 digits; wider code points would
        # silently yield longer groups that bin2string cannot split back apart.
        if code_point > 0xFF:
            raise ValueError(
                f'character {ch!r} (U+{code_point:04X}) does not fit in 8 bits; '
                'only characters with code points below 256 are supported'
            )
        bits.extend(int(b) for b in format(code_point, '08b'))
    return bits


def bin2string(binary_string):
    """Decode a string of '0'/'1' digits, 8 per character, back into text.

    :param binary_string: bit string as produced by joining ``string2bin`` output
    :return: decoded text (one character per 8-bit group)
    """
    return ''.join(chr(int(binary_string[i:i + 8], 2))
                   for i in range(0, len(binary_string), 8))


class Embedding():
    """White-box watermark encoder for convolutional models.

    Selected Conv2d weights are averaged over their last axis and flattened
    into a vector ``w``. A random projection matrix ``X_random`` (the secret
    key) maps ``w`` to one logit per code bit. Training with ``add_penalty``
    pushes ``sigmoid(X_random @ w)`` towards the code bits, and ``test``
    recovers the embedded string.
    """

    def __init__(self, model, code: str, key_path: str = None, l=1, train=True,
                 device=None):
        """Initialise the white-box watermark encoder.

        :param model: model whose Conv2d weights will carry the watermark
        :param code: watermark text; encoded to bits with ``string2bin``
        :param key_path: file where the projection matrix is saved (training)
            or loaded from (testing); defaults to ``./key.pt``
        :param l: weight of the watermark loss added by ``add_penalty``
        :param train: True to generate and save a new random projection matrix,
            False to load an existing one from ``key_path``
        :param device: device for the code and key tensors; defaults to CUDA
            when available, otherwise CPU
        :raises ValueError: if the model has no Conv2d layer, the code is empty,
            or a loaded key matrix does not match the code/parameter sizes
        """
        super().__init__()
        self.key_path = key_path if key_path is not None else './key.pt'
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.p = self.get_parameters(model)
        w = self.flatten_parameters(self.p)
        self.l = l
        self.w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)
        self.opt = Adam(self.p, lr=0.001)
        self.distribution_ignore = ['train_acc']

        bits = string2bin(code)
        if not bits:
            raise ValueError('watermark code must be a non-empty string')
        self.code = torch.tensor(bits, dtype=torch.float, device=self.device)
        self.code_len = self.code.shape[0]
        print(f'Code:{self.code} code length:{self.code_len}')

        expected_shape = (self.code_len, self.w_init.shape[0])
        if not train:
            # Test mode: reuse the key generated at training time.
            self.load_matrix(self.key_path)
            if tuple(self.X_random.shape) != expected_shape:
                raise ValueError(
                    f'key matrix at {self.key_path!r} has shape '
                    f'{tuple(self.X_random.shape)}, expected {expected_shape}'
                )
        else:
            # Drawn from the CPU generator (as before) so seeded runs reproduce,
            # then moved to the target device.
            self.X_random = torch.randn(expected_shape).to(self.device)
            self.save_matrix()

    def save_matrix(self, path=None):
        """Save the projection matrix to *path* (defaults to ``self.key_path``)."""
        torch.save(self.X_random, path if path is not None else self.key_path)

    def load_matrix(self, path=None):
        """Load the projection matrix from *path* (defaults to ``self.key_path``)."""
        target = path if path is not None else self.key_path
        # map_location also allows loading a CUDA-saved key on a CPU-only host.
        self.X_random = torch.load(target, map_location=self.device)

    def get_parameters(self, model):
        """Select the convolution weights that will carry the watermark.

        Conv2d layers with more than 100 output channels are preferred; if the
        model has none, every Conv2d layer is used as a candidate instead. If
        more than 11 candidates exist, candidates 10 and 11 (0-based) are
        chosen; otherwise the first two (or the only one).

        :param model: ``nn.Module`` to search
        :return: list of the chosen layers' ``weight`` parameters
        :raises ValueError: if the model contains no ``nn.Conv2d`` layer
        """
        convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
        wide = [m for m in convs if m.out_channels > 100]
        # Fall back to all convolutions for shallow or narrow models.
        conv_list = wide or convs
        if not conv_list:
            raise ValueError('model has no nn.Conv2d layer to embed the watermark into')
        if len(conv_list) > 11:
            target = conv_list[10:12]
        else:
            target = conv_list[0:2]
        print(f'Embedding target:{target}')
        return [x.weight for x in target]

    def add_penalty(self, loss):
        """Return ``loss + l * BCE(sigmoid(X_random @ w), code)``.

        The caller's ``loss`` is not modified in place.
        """
        w = self.flatten_parameters(self.p)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.code.to(prob.device))
        # Out-of-place add: an in-place `+=` would mutate the caller's tensor and
        # can break autograd when that tensor is needed for the backward pass.
        return loss + self.l * penalty

    def flatten_parameters(self, parameters):
        """Average each 4-D conv weight over its last axis and concatenate into 1-D."""
        return torch.cat([torch.mean(x, dim=3).reshape(-1) for x in parameters])

    def loss_fun(self, x, y):
        """Binary cross-entropy between bit probabilities *x* and target bits *y*."""
        return F.binary_cross_entropy(x, y)

    def decode(self, X, w):
        """Threshold ``sigmoid(X @ w)`` at 0.5 into a tensor of 0/1 bits."""
        prob = self.get_prob(X, w)
        return torch.where(prob > 0.5, 1, 0)

    def get_prob(self, X, w):
        """Return ``sigmoid(X @ w)`` as a 1-D tensor, one probability per row of *X*."""
        logits = torch.mm(X.to(w.device), w.reshape(-1, 1))
        return torch.sigmoid(logits).flatten()

    def test(self):
        """Decode the watermark currently held by the model and return it as text."""
        w = self.flatten_parameters(self.p)
        decoded = self.decode(self.X_random, w)
        print(decoded.shape)
        code_string = bin2string(''.join(str(bit) for bit in decoded.tolist()))
        print('decoded code:', code_string)
        return code_string