|
@@ -37,7 +37,7 @@ class WhiteBoxWatermarkProcessDefine:
|
|
|
self.x_random_file = x_random_file
|
|
|
self.public_key = public_key
|
|
|
|
|
|
- def extract_label(self, start, end):
|
|
|
+ def extract_label(self, scope, indices):
|
|
|
import onnx
|
|
|
import numpy as np
|
|
|
"""
|
|
@@ -55,7 +55,11 @@ class WhiteBoxWatermarkProcessDefine:
|
|
|
if initializer.name == weight_name:
|
|
|
# 获取权重数据
|
|
|
weights.append(onnx.numpy_helper.to_array(initializer))
|
|
|
- weights = weights[start:end]
|
|
|
+ if indices:
|
|
|
+ weights = [weights[i] for i in indices if i < len(weights)]
|
|
|
+ else:
|
|
|
+ start, end = scope
|
|
|
+ weights = weights[start:end]
|
|
|
weights = [np.transpose(weight, (2, 3, 1, 0)) for weight in
|
|
|
weights] # 将onnx文件的权重格式由(out_channels, in_channels, kernel_height, kernel_width),转换为(kernel_height, kernel_width, in_channels, out_channels)
|
|
|
x_random = np.load(self.x_random_file)
|
|
@@ -71,14 +75,15 @@ class WhiteBoxWatermarkProcessDefine:
|
|
|
secret_label = ''.join(chr(int(code_string[i:i + 8], 2)) for i in range(0, len(code_string), 8))
|
|
|
return secret_label
|
|
|
|
|
|
- def verify_label(self, start=0, end=3) -> bool:
|
|
|
+ def verify_label(self, scope=(0, 3), indices=None) -> bool:
|
|
|
"""
|
|
|
标签验证
|
|
|
- :param start: 嵌入标签开始卷积层位置,包括起始位
|
|
|
- :param end: 嵌入标签结束卷积层位置,不包括结束位置
|
|
|
+ :param scope:嵌入标签卷积层位置区间,默认值(0,3),包含开始位置,不包含结束位
|
|
|
+ :param indices: 如果指定该参数,会从卷积层指定索引列表中进行权重获取,scope参数无效
|
|
|
:return: 标签验证结果
|
|
|
"""
|
|
|
- secret_label = self.extract_label(start, end)
|
|
|
+ start, end = scope
|
|
|
+ secret_label = self.extract_label(scope, indices)
|
|
|
label_check_result = secret_label_func.verify_secret_label(secret_label=secret_label,
|
|
|
public_key=self.public_key)
|
|
|
return label_check_result
|