Browse Source

修改模型加载代码

liyan 11 months ago
parent
commit
62787fabad
2 changed files with 3 additions and 3 deletions
  1. 1 1
      watermark_codec/model_decoder.py
  2. 2 2
      watermark_codec/tool/tensor_deal.py

+ 1 - 1
watermark_codec/model_decoder.py

@@ -15,7 +15,7 @@ from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, ge
 
 
 class ModelDecoder:
-    def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cuda'):
+    def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cpu'):
         # 判断传入的层是否全部为卷积层
         for layer in layers:
             if not isinstance(layer, nn.Conv2d):

+ 2 - 2
watermark_codec/tool/tensor_deal.py

@@ -17,7 +17,7 @@ def save_tensor(tensor: Tensor, save_path: str):
     torch.save(tensor, save_path)
 
 
-def load_tensor(save_path, device='cuda') -> Tensor:
+def load_tensor(save_path, device='cpu') -> Tensor:
     """
     从指定文件获取张量,并移动到指定的设备上
     :param save_path: pt文件位置
@@ -26,7 +26,7 @@ def load_tensor(save_path, device='cuda') -> Tensor:
     """
     assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
     assert os.path.exists(save_path), f"{save_path}权重文件不存在"
-    return torch.load(save_path).to(device)
+    return torch.load(save_path, map_location=torch.device(device)).to(device)
 
 
 def flatten_parameters(weights: List[Tensor]) -> Tensor: