소스 검색

Merge branch 'refs/heads/master' into test

liyan 11 달 전
부모
커밋
82073bd659
3개의 변경된 파일4개의 추가작업 그리고 3개의 파일을 삭제
  1. 1 0
      README.md
  2. 1 1
      watermark_codec/model_decoder.py
  3. 2 2
      watermark_codec/tool/tensor_deal.py

+ 1 - 0
README.md

@@ -4,6 +4,7 @@
 ## 分支说明
 - `master`分支只包含项目打包配置和白盒水印编解码器源码
 - `test`分支在`master`分支基础上添加了测试模型、训练代码、验证代码
+- `mindspore`分支使用mindspore框架重新实现白盒水印嵌入代码,并使用此框架重写测试模型、训练代码、验证代码
 
 ## 文件组成
 ```text

+ 1 - 1
watermark_codec/model_decoder.py

@@ -15,7 +15,7 @@ from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, ge
 
 
 class ModelDecoder:
-    def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cuda'):
+    def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cpu'):
         # 判断传入的层是否全部为卷积层
         for layer in layers:
             if not isinstance(layer, nn.Conv2d):

+ 2 - 2
watermark_codec/tool/tensor_deal.py

@@ -17,7 +17,7 @@ def save_tensor(tensor: Tensor, save_path: str):
     torch.save(tensor, save_path)
 
 
-def load_tensor(save_path, device='cuda') -> Tensor:
+def load_tensor(save_path, device='cpu') -> Tensor:
     """
     从指定文件获取张量,并移动到指定的设备上
     :param save_path: pt文件位置
@@ -26,7 +26,7 @@ def load_tensor(save_path, device='cuda') -> Tensor:
     """
     assert save_path.endswith('.pt') or save_path.endswith('.pth'), f"权重保存文件必须以.pt或.pth结尾"
     assert os.path.exists(save_path), f"{save_path}权重文件不存在"
-    return torch.load(save_path).to(device)
+    return torch.load(save_path, map_location=torch.device(device)).to(device)
 
 
 def flatten_parameters(weights: List[Tensor]) -> Tensor: