"""
Created on 2024/5/8
@author:
@version: 1.0
@file: model_decoder.py
@brief White-box watermark decoder: recovers an embedded string from the
       weights of selected convolutional layers.
"""
from typing import List, Optional

import torch
from torch import nn

from watermark_codec.tool.str_convertor import bin2string
from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob


class ModelDecoder:
    """Extract a white-box watermark string from convolutional layer weights.

    At construction time the weights of the given ``nn.Conv2d`` layers are
    passed to ``flatten_parameters`` and a key tensor is loaded with
    ``load_tensor``. :meth:`decode` feeds both to ``get_prob``, thresholds
    the result at 0.5 into bits and converts the bit string with
    ``bin2string``.
    """

    def __init__(self, layers: List[nn.Conv2d], key_path: Optional[str] = None, device='cpu'):
        """
        :param layers: convolutional layers whose weights carry the watermark.
        :param key_path: path handed to ``load_tensor`` to obtain the key
            tensor. Defaults to ``None``; whether ``load_tensor`` accepts
            ``None`` is not visible here -- NOTE(review): confirm.
        :param device: device handed to ``load_tensor`` for the key tensor.
            NOTE(review): presumably must match the device of the layer
            weights for ``get_prob`` to work -- confirm.
        :raises TypeError: if any element of ``layers`` is not ``nn.Conv2d``.
        """
        # Only convolutional layers are accepted as watermark carriers.
        for layer in layers:
            if not isinstance(layer, nn.Conv2d):
                # Message reads: "argument is not a convolutional layer".
                raise TypeError('传入参数不是卷积层')
        weights = [x.weight for x in layers]  # collect the weight of every conv layer
        self.w = flatten_parameters(weights)
        self.x_random = load_tensor(key_path, device)
        # Not used anywhere in this class; kept so existing attribute access keeps working.
        self.model = None

    def decode(self):
        """Decode the watermark embedded in the layer weights.

        :return: the string produced by ``bin2string`` from the
            thresholded bit sequence.
        """
        # Decoding is pure inference: disable autograd so the computation in
        # get_prob over the (grad-requiring) conv weights records no graph.
        with torch.no_grad():
            prob = get_prob(self.x_random, self.w)
        # A bit is 1 where the probability exceeds 0.5, otherwise 0.
        bits = torch.where(prob > 0.5, 1, 0)
        code_string = ''.join(str(bit) for bit in bits.tolist())
        return bin2string(code_string)