123456789101112131415161718192021222324252627282930313233 |
- """
- Created on 2024/5/8
- @author: <NAME>
- @version: 1.0
- @file: model_decoder.py
- @brief 白盒水印解码器
- """
- from typing import List
- import torch
- from torch import nn
- from watermark_codec.tool.str_convertor import bin2string
- from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob
class ModelDecoder:
    """White-box watermark decoder operating on convolution-layer weights.

    The decoder gathers the ``weight`` tensors of the given ``nn.Conv2d``
    layers, combines them with a key tensor loaded from ``key_path``, and
    turns the resulting per-bit values into a decoded string.

    NOTE(review): the exact semantics of ``flatten_parameters``,
    ``load_tensor``, ``get_prob`` and ``bin2string`` live in
    ``watermark_codec.tool`` and are not visible here; the descriptions
    below are inferred from how their results are used -- confirm against
    those helpers.
    """

    def __init__(self, layers: List[nn.Conv2d], key_path: str = None, device='cuda'):
        """Collect the layer weights and load the key tensor.

        Args:
            layers: Convolution layers whose weights carry the watermark.
                Every element must be an ``nn.Conv2d`` instance.
            key_path: Path passed to ``load_tensor`` to obtain the key
                (presumably the projection matrix used at embedding time --
                TODO confirm).
            device: Device argument forwarded to ``load_tensor``.

        Raises:
            TypeError: If any element of ``layers`` is not an ``nn.Conv2d``.
        """
        # Refuse the whole input as soon as one element is not a conv layer.
        if not all(isinstance(conv, nn.Conv2d) for conv in layers):
            raise TypeError('传入参数不是卷积层')

        # Gather every conv layer's weight and hand them to the flattening
        # helper; its result is what ``decode`` later reads as ``self.w``.
        self.w = flatten_parameters([conv.weight for conv in layers])
        self.x_random = load_tensor(key_path, device)
        # Not used anywhere in this class; kept because external code may
        # read or assign it.
        self.model = None

    def decode(self):
        """Decode the watermark from the stored weights.

        Values returned by ``get_prob`` that are strictly greater than 0.5
        become bit ``1`` and all others become bit ``0``. The bits are
        concatenated into a string of ``'0'``/``'1'`` characters and passed
        through ``bin2string``.

        Returns:
            Whatever ``bin2string`` returns for the recovered bit string
            (presumably the decoded text -- confirm against the helper).
        """
        probabilities = get_prob(self.x_random, self.w)
        # Hard threshold at 0.5 to turn each value into a 0/1 bit.
        bits = torch.where(probabilities > 0.5, 1, 0)
        bit_string = ''.join(map(str, bits.tolist()))
        return bin2string(bit_string)
|