12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import os
- from typing import List
- import torch
- from torch import Tensor
- import torch.nn.functional as F
- def save_tensor(tensor: Tensor, save_path: str):
- """
- 保存张量至指定文件
- :param tensor:待保存的张量
- :param save_path: 保存位置,例如:/home/secret.pt
- """
- assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾"
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
- torch.save(tensor, save_path)
- def load_tensor(save_path, device='cuda') -> Tensor:
- """
- 从指定文件获取张量,并移动到指定的设备上
- :param save_path: pt文件位置
- :param device: 加载至指定设备,默认为cuda
- :return: 指定张量
- """
- assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
- assert os.path.exists(save_path), f"{save_path}权重文件不存在"
- return torch.load(save_path).to(device)
- def flatten_parameters(weights: List[Tensor]) -> Tensor:
- """
- 处理传入的卷积层的权重参数
- :param weights: 指定卷积层的权重列表
- :return: 处理完成返回的张量
- """
- return torch.cat([torch.mean(x, dim=3).reshape(-1)
- for x in weights])
- def get_prob(x_random, w) -> Tensor:
- """
- 获取投影矩阵与权重向量的计算结果
- :param x_random: 投影矩阵
- :param w: 权重向量
- :return: 计算记过
- """
- mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
- return F.sigmoid(mm).flatten()
- def loss_fun(x, y) -> Tensor:
- """
- 计算白盒水印嵌入时的损失
- :param x: 预测值
- :param y: 实际值
- :return: 损失
- """
- return F.binary_cross_entropy(x, y)
- if __name__ == '__main__':
- key_path = './secret.pt'
- device = 'cuda'
- # 生成随机矩阵
- X_random = torch.randn((2, 3)).to(device)
- save_tensor(X_random, key_path) # 保存矩阵至指定位置
- tensor_load = load_tensor(key_path, device)
- print(tensor_load)
|