|
@@ -1,9 +1,9 @@
|
|
import os
|
|
import os
|
|
from typing import List
|
|
from typing import List
|
|
|
|
|
|
-import torch
|
|
|
|
-from torch import Tensor
|
|
|
|
-import torch.nn.functional as F
|
|
|
|
|
|
+import mindspore as ms
|
|
|
|
+from mindspore import Tensor
|
|
|
|
+import mindspore.numpy as mnp
|
|
|
|
|
|
|
|
|
|
def save_tensor(tensor: Tensor, save_path: str):
|
|
def save_tensor(tensor: Tensor, save_path: str):
|
|
@@ -12,21 +12,21 @@ def save_tensor(tensor: Tensor, save_path: str):
|
|
:param tensor:待保存的张量
|
|
:param tensor:待保存的张量
|
|
:param save_path: 保存位置,例如:/home/secret.pt
|
|
:param save_path: 保存位置,例如:/home/secret.pt
|
|
"""
|
|
"""
|
|
- assert save_path.endswith('.pt') or save_path.endswith('.pth'), "权重保存文件必须以.pt或.pth结尾"
|
|
|
|
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
- torch.save(tensor, save_path)
|
|
|
|
|
|
+ assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
|
|
|
|
+ save_obj = [{"name": 'x_random', "data": tensor}]
|
|
|
|
+ ms.save_checkpoint(save_obj, save_path)
|
|
|
|
|
|
|
|
|
|
-def load_tensor(save_path, device='cuda') -> Tensor:
|
|
|
|
|
|
+def load_tensor(save_path) -> Tensor:
|
|
"""
|
|
"""
|
|
从指定文件获取张量,并移动到指定的设备上
|
|
从指定文件获取张量,并移动到指定的设备上
|
|
:param save_path: pt文件位置
|
|
:param save_path: pt文件位置
|
|
- :param device: 加载至指定设备,默认为cuda
|
|
|
|
:return: 指定张量
|
|
:return: 指定张量
|
|
"""
|
|
"""
|
|
- assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
|
|
|
|
|
|
+ assert save_path.endswith('.ckpt'), "权重保存文件必须以.ckpt结尾"
|
|
assert os.path.exists(save_path), f"{save_path}权重文件不存在"
|
|
assert os.path.exists(save_path), f"{save_path}权重文件不存在"
|
|
- return torch.load(save_path).to(device)
|
|
|
|
|
|
+ save_obj = ms.load_checkpoint(save_path)
|
|
|
|
+ return save_obj['x_random']
|
|
|
|
|
|
|
|
|
|
def flatten_parameters(weights: List[Tensor]) -> Tensor:
|
|
def flatten_parameters(weights: List[Tensor]) -> Tensor:
|
|
@@ -35,8 +35,14 @@ def flatten_parameters(weights: List[Tensor]) -> Tensor:
|
|
:param weights: 指定卷积层的权重列表
|
|
:param weights: 指定卷积层的权重列表
|
|
:return: 处理完成返回的张量
|
|
:return: 处理完成返回的张量
|
|
"""
|
|
"""
|
|
- return torch.cat([torch.mean(x, dim=3).reshape(-1)
|
|
|
|
- for x in weights])
|
|
|
|
|
|
+ # 假设 weights 是一个包含 MindSpore Tensor 的列表
|
|
|
|
+ mean_list = []
|
|
|
|
+ for x in weights:
|
|
|
|
+ mean_x = mnp.mean(x, axis=3).reshape(-1)
|
|
|
|
+ mean_list.append(mean_x)
|
|
|
|
+
|
|
|
|
+ concat = ms.ops.Concat(1)
|
|
|
|
+ return concat(mean_list)
|
|
|
|
|
|
|
|
|
|
def get_prob(x_random, w) -> Tensor:
|
|
def get_prob(x_random, w) -> Tensor:
|
|
@@ -46,8 +52,8 @@ def get_prob(x_random, w) -> Tensor:
|
|
:param w: 权重向量
|
|
:param w: 权重向量
|
|
:return: 计算记过
|
|
:return: 计算记过
|
|
"""
|
|
"""
|
|
- mm = torch.mm(x_random, w.reshape((w.shape[0], 1)))
|
|
|
|
- return F.sigmoid(mm).flatten()
|
|
|
|
|
|
+ mm = ms.ops.mm(x_random, w.reshape((w.shape[0], 1)))
|
|
|
|
+ return ms.ops.sigmoid(mm).flatten()
|
|
|
|
|
|
|
|
|
|
def loss_fun(x, y) -> Tensor:
|
|
def loss_fun(x, y) -> Tensor:
|
|
@@ -57,14 +63,13 @@ def loss_fun(x, y) -> Tensor:
|
|
:param y: 实际值
|
|
:param y: 实际值
|
|
:return: 损失
|
|
:return: 损失
|
|
"""
|
|
"""
|
|
- return F.binary_cross_entropy(x, y)
|
|
|
|
-
|
|
|
|
|
|
+ return ms.ops.binary_cross_entropy(x, y)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- key_path = './secret.pt'
|
|
|
|
- device = 'cuda'
|
|
|
|
|
|
+ key_path = './secret.ckpt'
|
|
# 生成随机矩阵
|
|
# 生成随机矩阵
|
|
- X_random = torch.randn((2, 3)).to(device)
|
|
|
|
|
|
+ X_random = ms.ops.randn((2, 3))
|
|
|
|
+ print(X_random)
|
|
save_tensor(X_random, key_path) # 保存矩阵至指定位置
|
|
save_tensor(X_random, key_path) # 保存矩阵至指定位置
|
|
- tensor_load = load_tensor(key_path, device)
|
|
|
|
|
|
+ tensor_load = load_tensor(key_path)
|
|
print(tensor_load)
|
|
print(tensor_load)
|