@@ -14,14 +14,15 @@ def bin2string(binary_string):
 
 
 class Embedding():
-    def __init__(self, model, code: torch.Tensor, key_path: str = None, l=1, train=True):
+    def __init__(self, model, code, key_path: str = None, l=1, train=True, device='cuda'):
         """
         Initialize the white-box watermark embedder.
         :param model: the model definition
-        :param code: the key, converted to Tensor format
+        :param code: the key, as a string
         :param key_path: path where the projection-matrix weight file is saved
         :param l: loss weight of the watermark embedder
         :param train: whether this is a training environment, defaults to True
+        :param device: device to run on, defaults to 'cuda'
         """
         super(Embedding, self).__init__()
 
@@ -29,6 +30,8 @@ class Embedding():
         self.key_path = key_path if key_path is not None else './key.pt'
 
+        self.device = device
+
         # self.p = parameters
 
         # self.w = nn.Parameter(w, requires_grad=True)
 
@@ -45,7 +48,7 @@ class Embedding():
         self.opt = Adam(self.p, lr=0.001)
         self.distribution_ignore = ['train_acc']
         self.code = torch.tensor(string2bin(
-            code), dtype=torch.float).cuda()  # the embedding code
+            code), dtype=torch.float).to(self.device)  # the embedding code
         self.code_len = self.code.shape[0]
         print(f'Code:{self.code} code length:{self.code_len}')
 
@@ -54,14 +57,14 @@ class Embedding():
             self.load_matrix(key_path)
         else:
             self.X_random = torch.randn(
-                (self.code_len, self.w_init.shape[0])).cuda()
+                (self.code_len, self.w_init.shape[0])).to(self.device)
             self.save_matrix()
 
     def save_matrix(self):
         torch.save(self.X_random, self.key_path)
 
     def load_matrix(self, path):
-        self.X_random = torch.load(path).cuda()
+        self.X_random = torch.load(path).to(self.device)
 
     def get_parameters(self, model):
         # conv_list = []