model_decoder.py 1.0 KB

123456789101112131415161718192021222324252627282930313233
  1. """
  2. Created on 2024/5/8
  3. @author: <NAME>
  4. @version: 1.0
  5. @file: model_decoder.py
  6. @brief 白盒水印解码器
  7. """
  8. from typing import List
  9. import torch
  10. from torch import nn
  11. from watermark_codec.tool.str_convertor import bin2string
  12. from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob
  13. class ModelDecoder:
  14. def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cuda'):
  15. # 判断传入的层是否全部为卷积层
  16. for layer in layers:
  17. if not isinstance(layer, nn.Conv2d):
  18. raise TypeError('传入参数不是卷积层')
  19. weights = [x.weight for x in layers] # 获取所有卷积层权重
  20. self.w = flatten_parameters(weights)
  21. self.x_random = load_tensor(key_path, device)
  22. self.model = None
  23. def decode(self):
  24. prob = get_prob(self.x_random, self.w)
  25. decode = torch.where(prob > 0.5, 1, 0)
  26. code_string = ''.join([str(x) for x in decode.tolist()])
  27. code_string = bin2string(code_string)
  28. return code_string