|
@@ -8,19 +8,28 @@ Created on 2024/5/8
|
|
from typing import List
|
|
from typing import List
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
-from mindspore import nn
|
|
|
|
|
|
+from mindspore import nn, Tensor
|
|
|
|
|
|
from watermark_codec.tool.str_convertor import bin2string
|
|
from watermark_codec.tool.str_convertor import bin2string
|
|
from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob
|
|
from watermark_codec.tool.tensor_deal import load_tensor, flatten_parameters, get_prob
|
|
|
|
|
|
|
|
|
|
class ModelDecoder:
|
|
class ModelDecoder:
|
|
- def __init__(self, layers: List[nn.Conv2d], key_path: str = None):
|
|
|
|
|
|
+ def __init__(self, layers: List[nn.Conv2d] = None, weights: List[Tensor] = None, key_path: str = None):
|
|
|
|
+ """
|
|
|
|
+ 初始化白盒模型解码器
|
|
|
|
+ :param layers: 模型嵌入白盒水印的加密层
|
|
|
|
+ :param weights: 模型嵌入白盒水印的加密层的权重,当weights不为None时,layers参数无效
|
|
|
|
+ :param key_path: 投影矩阵导出文件位置,加载投影矩阵使用
|
|
|
|
+ """
|
|
# 判断传入的层是否全部为卷积层
|
|
# 判断传入的层是否全部为卷积层
|
|
- for layer in layers:
|
|
|
|
- if not isinstance(layer, nn.Conv2d):
|
|
|
|
- raise TypeError('传入参数不是卷积层')
|
|
|
|
- weights = [x.weight for x in layers] # 获取所有卷积层权重
|
|
|
|
|
|
+ if layers is None and weights is None:
|
|
|
|
+ raise RuntimeError('layers和weights不可同时为空')
|
|
|
|
+ if weights is None:
|
|
|
|
+ for layer in layers:
|
|
|
|
+ if not isinstance(layer, nn.Conv2d):
|
|
|
|
+ raise TypeError('传入参数不是卷积层')
|
|
|
|
+ weights = [x.weight for x in layers] # 获取所有卷积层权重
|
|
self.w = flatten_parameters(weights)
|
|
self.w = flatten_parameters(weights)
|
|
self.x_random = load_tensor(key_path)
|
|
self.x_random = load_tensor(key_path)
|
|
self.model = None
|
|
self.model = None
|