123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import os
- from typing import List
- import mindspore as ms
- from mindspore import Tensor
- import mindspore.numpy as mnp
- def save_tensor(tensor: Tensor, save_path: str):
- """
- 保存张量至指定文件
- :param tensor:待保存的张量
- :param save_path: 保存位置,例如:/home/secret.pt
- """
- assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
- save_obj = [{"name": 'x_random', "data": tensor}]
- ms.save_checkpoint(save_obj, save_path)
- def load_tensor(save_path) -> Tensor:
- """
- 从指定文件获取张量,并移动到指定的设备上
- :param save_path: pt文件位置
- :return: 指定张量
- """
- assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
- assert os.path.exists(save_path), f"{save_path}权重文件不存在"
- save_obj = ms.load_checkpoint(save_path)
- return save_obj['x_random']
- def flatten_parameters(weights: List[Tensor]) -> Tensor:
- """
- 处理传入的卷积层的权重参数
- :param weights: 指定卷积层的权重列表
- :return: 处理完成返回的张量
- """
- # 假设 weights 是一个包含 MindSpore Tensor 的列表
- mean_list = []
- for x in weights:
- mean_x = mnp.mean(x, axis=3).reshape(-1)
- mean_list.append(mean_x)
- concat = ms.ops.Concat(1)
- return concat(mean_list)
- def get_prob(x_random, w) -> Tensor:
- """
- 获取投影矩阵与权重向量的计算结果
- :param x_random: 投影矩阵
- :param w: 权重向量
- :return: 计算记过
- """
- mm = ms.ops.mm(x_random, w.reshape((w.shape[0], 1)))
- return ms.ops.sigmoid(mm).flatten()
- def loss_fun(x, y) -> Tensor:
- """
- 计算白盒水印嵌入时的损失
- :param x: 预测值
- :param y: 实际值
- :return: 损失
- """
- return ms.ops.binary_cross_entropy(x, y)
- if __name__ == '__main__':
- key_path = './secret.ckpt'
- # 生成随机矩阵
- X_random = ms.ops.randn((2, 3))
- print(X_random)
- save_tensor(X_random, key_path) # 保存矩阵至指定位置
- tensor_load = load_tensor(key_path)
- print(tensor_load)
|