import os from typing import List import torch from torch import Tensor import torch.nn.functional as F def save_tensor(tensor: Tensor, save_path: str): """ 保存张量至指定文件 :param tensor:待保存的张量 :param save_path: 保存位置,例如:/home/secret.pt """ assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾" os.makedirs(os.path.dirname(save_path), exist_ok=True) torch.save(tensor, save_path) def load_tensor(save_path, device='cpu') -> Tensor: """ 从指定文件获取张量,并移动到指定的设备上 :param save_path: pt文件位置 :param device: 加载至指定设备,默认为cuda :return: 指定张量 """ assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾" assert os.path.exists(save_path), f"{save_path}权重文件不存在" return torch.load(save_path, map_location=torch.device(device)).to(device) def flatten_parameters(weights: List[Tensor]) -> Tensor: """ 处理传入的卷积层的权重参数 :param weights: 指定卷积层的权重列表 :return: 处理完成返回的张量 """ return torch.cat([torch.mean(x, dim=3).reshape(-1) for x in weights]) def get_prob(x_random, w) -> Tensor: """ 获取投影矩阵与权重向量的计算结果 :param x_random: 投影矩阵 :param w: 权重向量 :return: 计算记过 """ mm = torch.mm(x_random, w.reshape((w.shape[0], 1))) return F.sigmoid(mm).flatten() def loss_fun(x, y) -> Tensor: """ 计算白盒水印嵌入时的损失 :param x: 预测值 :param y: 实际值 :return: 损失 """ return F.binary_cross_entropy(x, y) if __name__ == '__main__': key_path = './secret.pt' device = 'cuda' # 生成随机矩阵 X_random = torch.randn((2, 3)).to(device) save_tensor(X_random, key_path) # 保存矩阵至指定位置 tensor_load = load_tensor(key_path, device) print(tensor_load)