|
@@ -56,7 +56,10 @@ def modify_model_project(secret_label: str, project_dir: str, public_key: str):
|
|
|
print(f'Secret:{self.secret} secret length:{self.secret_len}')
|
|
|
|
|
|
# 生成随机的投影矩阵
|
|
|
- self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
|
|
|
+ if os.path.exists(key_path):
|
|
|
+ self.X_random = np.load(key_path)
|
|
|
+ else:
|
|
|
+ self.X_random = torch.randn((self.secret_len, w_init.shape[0])).to(self.device)
|
|
|
self.save_tensor(self.X_random, key_path) # 保存投影矩阵至指定位置
|
|
|
|
|
|
def get_embeder_loss(self):
|