tensor_deal.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. from typing import List
  3. import mindspore as ms
  4. from mindspore import Tensor
  5. import mindspore.numpy as mnp
  6. def save_tensor(tensor: Tensor, save_path: str):
  7. """
  8. 保存张量至指定文件
  9. :param tensor:待保存的张量
  10. :param save_path: 保存位置,例如:/home/secret.pt
  11. """
  12. assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
  13. save_obj = [{"name": 'x_random', "data": tensor}]
  14. ms.save_checkpoint(save_obj, save_path)
  15. def load_tensor(save_path) -> Tensor:
  16. """
  17. 从指定文件获取张量,并移动到指定的设备上
  18. :param save_path: pt文件位置
  19. :return: 指定张量
  20. """
  21. assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
  22. assert os.path.exists(save_path), f"{save_path}权重文件不存在"
  23. save_obj = ms.load_checkpoint(save_path)
  24. return save_obj['x_random']
  25. def flatten_parameters(weights: List[Tensor]) -> Tensor:
  26. """
  27. 处理传入的卷积层的权重参数
  28. :param weights: 指定卷积层的权重列表
  29. :return: 处理完成返回的张量
  30. """
  31. # 假设 weights 是一个包含 MindSpore Tensor 的列表
  32. mean_list = []
  33. for x in weights:
  34. mean_x = mnp.mean(x, axis=3).reshape(-1)
  35. mean_list.append(mean_x)
  36. return ms.ops.concat(mean_list)
  37. def get_prob(x_random, w) -> Tensor:
  38. """
  39. 获取投影矩阵与权重向量的计算结果
  40. :param x_random: 投影矩阵
  41. :param w: 权重向量
  42. :return: 计算记过
  43. """
  44. mm = ms.ops.mm(x_random, w.reshape((w.shape[0], 1)))
  45. return ms.ops.sigmoid(mm).flatten()
  46. def loss_fun(x, y) -> Tensor:
  47. """
  48. 计算白盒水印嵌入时的损失
  49. :param x: 预测值
  50. :param y: 实际值
  51. :return: 损失
  52. """
  53. return ms.ops.binary_cross_entropy(x, y)
  54. if __name__ == '__main__':
  55. ms.set_context(device_target="GPU")
  56. key_path = './secret.ckpt'
  57. # 生成随机矩阵
  58. X_random = ms.ops.randn((2, 3))
  59. print(X_random)
  60. save_tensor(X_random, key_path) # 保存矩阵至指定位置
  61. tensor_load = load_tensor(key_path)
  62. print(tensor_load)