Browse Source

修改白盒水印编码器,支持运行设备参数

liyan 1 năm trước cách đây
mục cha
commit
b29a335740
1 tập tin đã thay đổi với 8 bổ sung5 xóa
  1. 8 5
      tool/training_embedding.py

+ 8 - 5
tool/training_embedding.py

@@ -14,14 +14,15 @@ def bin2string(binary_string):
 
 
 class Embedding():
-    def __init__(self, model, code: torch.Tensor, key_path: str = None, l=1, train=True):
+    def __init__(self, model, code, key_path: str = None, l=1, train=True, device='cuda'):
         """
         初始化白盒水印编码器
         :param model: 模型定义
-        :param code: 密钥,转换为Tensor格式
+        :param code: 密钥,字符串格式
         :param key_path: 投影矩阵权重文件保存路径
         :param l: 水印编码器loss权重
         :param train: 是否是训练环境,默认为True
+        :param device: 运行设备,默认为cuda
         """
         super(Embedding, self).__init__()
 
@@ -29,6 +30,8 @@ class Embedding():
 
         self.key_path = key_path if key_path is not None else './key.pt'
 
+        self.device = device
+
         # self.p = parameters
 
         # self.w = nn.Parameter(w, requires_grad=True)
@@ -45,7 +48,7 @@ class Embedding():
         self.opt = Adam(self.p, lr=0.001)
         self.distribution_ignore = ['train_acc']
         self.code = torch.tensor(string2bin(
-            code), dtype=torch.float).cuda()  # the embedding code
+            code), dtype=torch.float).to(self.device)  # the embedding code
         self.code_len = self.code.shape[0]
         print(f'Code:{self.code} code length:{self.code_len}')
 
@@ -54,14 +57,14 @@ class Embedding():
             self.load_matrix(key_path)
         else:
             self.X_random = torch.randn(
-                (self.code_len, self.w_init.shape[0])).cuda()
+                (self.code_len, self.w_init.shape[0])).to(self.device)
             self.save_matrix()
 
     def save_matrix(self):
         torch.save(self.X_random, self.key_path)
 
     def load_matrix(self, path):
-        self.X_random = torch.load(path).cuda()
+        self.X_random = torch.load(path).to(self.device)
 
     def get_parameters(self, model):
         # conv_list = []