import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

def string2bin(s):
    # Encode each character of s as 8 bits; return the bits as a list of ints.
    binary_representation = ''.join(format(ord(x), '08b') for x in s)
    return [int(x) for x in binary_representation]


def bin2string(binary_string):
    # Inverse of string2bin: consume the bit string 8 bits at a time.
    return ''.join(chr(int(binary_string[i:i + 8], 2))
                   for i in range(0, len(binary_string), 8))
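
# Illustrative round trip (example added here, not from the original source):
# string2bin('AI') returns [0,1,0,0,0,0,0,1, 0,1,0,0,1,0,0,1]; joining those
# bits into a string and passing it to bin2string recovers 'AI'.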


class Embedding:
    def __init__(self, model, code: str, key_path: str = None, l=1, train=True):
        """
        Initialize the white-box watermark embedder.
        :param model: the model to be watermarked
        :param code: the secret key string, converted internally to a bit tensor
        :param key_path: path where the projection-matrix file is saved
        :param l: weight of the watermark penalty in the total loss
        :param train: whether this is a training run (default True)
        """
        self.key_path = key_path if key_path is not None else './key.pt'
        self.p = self.get_parameters(model)
        # Flattened view of the embedding-target parameters.
        w = self.flatten_parameters(self.p)
        self.l = l
        self.w_init = w.clone().detach()
        print('Size of embedding parameters:', w.shape)
        # Internal optimizer over just the watermarked weights; add_penalty
        # does not use it, so the caller's own optimizer drives the update.
        self.opt = Adam(self.p, lr=0.001)
        self.distribution_ignore = ['train_acc']
        # The secret key, encoded as a 0/1 tensor.
        self.code = torch.tensor(string2bin(code), dtype=torch.float).cuda()
        self.code_len = self.code.shape[0]
        print(f'Code: {self.code} code length: {self.code_len}')
        # In a test environment, load the saved projection matrix directly;
        # in a training environment, draw a random matrix X and save it to
        # key_path.
        if not train:
            self.load_matrix(self.key_path)
        else:
            self.X_random = torch.randn(
                (self.code_len, self.w_init.shape[0])).cuda()
            self.save_matrix()
    def save_matrix(self):
        torch.save(self.X_random, self.key_path)

    def load_matrix(self, path):
        self.X_random = torch.load(path).cuda()
    def get_parameters(self, model):
        # Collect wide conv layers (out_channels > 100) as embedding targets.
        conv_list = []
        for module in model.modules():
            if isinstance(module, nn.Conv2d) and module.out_channels > 100:
                conv_list.append(module)
        # Fallback for shallow models where no conv layer is wide enough:
        # accept any conv layer.
        if len(conv_list) == 0:
            for module in model.modules():
                if isinstance(module, nn.Conv2d):
                    conv_list.append(module)
        # Prefer two mid-network layers when the model is deep enough,
        # otherwise the first two (or the single) available layers.
        if len(conv_list) > 11:
            target = conv_list[10:12]
        elif len(conv_list) >= 2:
            target = conv_list[0:2]
        else:
            target = conv_list
        print(f'Embedding target: {target}')
        return [x.weight for x in target]
    def add_penalty(self, loss):
        # Add the watermark regularizer to the task loss: binary cross-entropy
        # between sigmoid(X·w) and the code bits, weighted by self.l.
        w = self.flatten_parameters(self.p)
        prob = self.get_prob(self.X_random, w)
        penalty = self.loss_fun(prob, self.code)
        loss += self.l * penalty
        return loss
    def flatten_parameters(self, parameters):
        # Average each kernel over its last dimension, then concatenate all
        # target layers into one flat vector.
        return torch.cat([torch.mean(x, dim=3).reshape(-1)
                          for x in parameters])

    def loss_fun(self, x, y):
        return F.binary_cross_entropy(x, y)
    def decode(self, X, w):
        # Threshold the per-bit probabilities at 0.5 to recover the code bits.
        prob = self.get_prob(X, w)
        return torch.where(prob > 0.5, 1, 0)

    def get_prob(self, X, w):
        # Project the flattened weights through X and squash to (0, 1).
        mm = torch.mm(X, w.reshape((w.shape[0], 1)))
        return torch.sigmoid(mm).flatten()
    def test(self):
        # Extract the watermark from the current weights and decode it back
        # to a string.
        w = self.flatten_parameters(self.p)
        decoded = self.decode(self.X_random, w)
        print(decoded.shape)
        code_string = ''.join(str(x) for x in decoded.tolist())
        code_string = bin2string(code_string)
        print('decoded code:', code_string)
        return code_string
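

# Minimal end-to-end sketch (illustrative usage, not from the original
# source): embed a two-character code into a toy CNN trained on random
# data, then extract it. A CUDA device is assumed, since Embedding calls
# .cuda(); the Sequential model and the squared-output loss below are
# hypothetical stand-ins for a real network and task loss.
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(3, 128, 3, padding=1),   # out_channels > 100, so selected
        nn.ReLU(),
        nn.Conv2d(128, 128, 3, padding=1),
    ).cuda()
    embedder = Embedding(model, code='ok', key_path='./key.pt')
    opt = Adam(model.parameters(), lr=1e-3)
    for _ in range(200):
        x = torch.randn(4, 3, 8, 8).cuda()
        task_loss = model(x).pow(2).mean()      # stand-in for the real task loss
        loss = embedder.add_penalty(task_loss)  # fold in the watermark penalty
        opt.zero_grad()
        loss.backward()
        opt.step()
    embedder.test()  # prints the decoded string; 'ok' once embedding converges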