model_decoder.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
  1. """
  2. Created on 2024/5/8
  3. @author: <NAME>
  4. @version: 1.0
  5. @file: model_decoder.py
  6. @brief 白盒水印解码器
  7. """
from typing import List, Optional

import numpy as np
from mindspore import nn

from watermark_codec.tool.str_convertor import bin2string
from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob
  13. class ModelDecoder:
  14. def __init__(self, layers: List[nn.Conv2d], key_path: str = None):
  15. # 判断传入的层是否全部为卷积层
  16. for layer in layers:
  17. if not isinstance(layer, nn.Conv2d):
  18. raise TypeError('传入参数不是卷积层')
  19. weights = [x.weight for x in layers] # 获取所有卷积层权重
  20. self.w = flatten_parameters(weights)
  21. self.x_random = load_tensor(key_path)
  22. self.model = None
  23. def decode(self):
  24. prob = get_prob(self.x_random, self.w)
  25. prob = prob.asnumpy()
  26. decode = np.where(prob > 0.5, 1, 0)
  27. code_string = ''.join([str(x) for x in decode.tolist()])
  28. code_string = bin2string(code_string)
  29. return code_string