tensor_deal.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import os
  2. from typing import List
  3. import torch
  4. from torch import Tensor
  5. import torch.nn.functional as F
  6. def save_tensor(tensor: Tensor, save_path: str):
  7. """
  8. 保存张量至指定文件
  9. :param tensor:待保存的张量
  10. :param save_path: 保存位置,例如:/home/secret.pt
  11. """
  12. assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾"
  13. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  14. torch.save(tensor, save_path)
  15. def load_tensor(save_path, device='cuda') -> Tensor:
  16. """
  17. 从指定文件获取张量,并移动到指定的设备上
  18. :param save_path: pt文件位置
  19. :param device: 加载至指定设备,默认为cuda
  20. :return: 指定张量
  21. """
  22. assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
  23. assert os.path.exists(save_path), f"{save_path}权重文件不存在"
  24. return torch.load(save_path).to(device)
  25. def flatten_parameters(weights: List[Tensor]) -> Tensor:
  26. """
  27. 处理传入的卷积层的权重参数
  28. :param weights: 指定卷积层的权重列表
  29. :return: 处理完成返回的张量
  30. """
  31. return torch.cat([torch.mean(x, dim=3).reshape(-1)
  32. for x in weights])
  33. def get_prob(x_random, w) -> Tensor:
  34. """
  35. 获取投影矩阵与权重向量的计算结果
  36. :param x_random: 投影矩阵
  37. :param w: 权重向量
  38. :return: 计算记过
  39. """
  40. mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
  41. return F.sigmoid(mm).flatten()
  42. def loss_fun(x, y) -> Tensor:
  43. """
  44. 计算白盒水印嵌入时的损失
  45. :param x: 预测值
  46. :param y: 实际值
  47. :return: 损失
  48. """
  49. return F.binary_cross_entropy(x, y)
  50. if __name__ == '__main__':
  51. key_path = './secret.pt'
  52. device = 'cuda'
  53. # 生成随机矩阵
  54. X_random = torch.randn((2, 3)).to(device)
  55. save_tensor(X_random, key_path) # 保存矩阵至指定位置
  56. tensor_load = load_tensor(key_path, device)
  57. print(tensor_load)