|
@@ -17,7 +17,7 @@ def save_tensor(tensor: Tensor, save_path: str):
|
|
|
torch.save(tensor, save_path)
|
|
|
|
|
|
|
|
|
-def load_tensor(save_path, device='cuda') -> Tensor:
|
|
|
+def load_tensor(save_path, device='cpu') -> Tensor:
|
|
|
"""
|
|
|
从指定文件获取张量,并移动到指定的设备上
|
|
|
:param save_path: pt文件位置
|
|
@@ -26,7 +26,7 @@ def load_tensor(save_path, device='cuda') -> Tensor:
|
|
|
"""
|
|
|
assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
|
|
|
assert os.path.exists(save_path), f"{save_path}权重文件不存在"
|
|
|
- return torch.load(save_path).to(device)
|
|
|
+ return torch.load(save_path, map_location=torch.device(device)).to(device)
|
|
|
|
|
|
|
|
|
def flatten_parameters(weights: List[Tensor]) -> Tensor:
|