liyan 85312b7bf2 新增签名验签工具及软算法模拟 | 3 mesi fa | |
---|---|---|
watermark_codec | 3 mesi fa | |
.gitignore | 6 mesi fa | |
MANIFEST.in | 6 mesi fa | |
README.md | 4 mesi fa | |
setup.py | 6 mesi fa |
提供模型训练嵌入白盒水印和从已经嵌入白盒水印的模型中提取水印的功能
master
分支只包含项目打包配置和白盒水印编解码器源码test
分支在master
分支基础上添加了测试模型、训练代码、验证代码mindspore
分支使用mindspore框架重新实现白盒水印嵌入代码,并使用此框架重写测试模型、训练代码、验证代码watermark_codec_pkg
├── MANIFEST.in # 打包配置文件
├── README.md # 项目说明文件
├── setup.py # 项目打包配置文件
└── watermark_codec # 白盒水印编解码器源码
├── __init__.py
├── model_decoder.py # 白盒水印解码器
├── model_encoder.py # 白盒水印编码器
└── tool # 工具脚本文件夹
├── __init__.py
├── str_convertor.py # 字符串转换
└── tensor_deal.py # 张量处理
白盒水印编码器使用
import torch.nn as nn
from model.Alexnet import Alexnet
from watermark_codec import ModelEncoder
from watermark_codec.tool import secret_func
# 创建AlexNet模型实例
model = Alexnet(3, 10, 32).to('cuda')
# 获取模型中待嵌入的卷积层
conv_list = []
for module in model.modules():
if isinstance(module, nn.Conv2d):
conv_list.append(module)
conv_list = conv_list[0:2]
secret = secret_func.get_secret(512) # 获取密钥
# 初始化模型水印编码器
encoder = ModelEncoder(layers=conv_list, secret=secret, key_path='watermark_codec/run/train/key.pt', device='cuda')
# ------------------------ 训练过程 -------------------------------#
# 实际应用只调用get_loss修改原损失即可
loss = encoder.get_loss(loss) # loss变量为原模型损失
白盒水印解码器使用
# 测试水印嵌入
import torch
from torch import nn
from model.Alexnet import Alexnet
from watermark_codec import ModelDecoder
model_path = './run/train/alex_net.pt'
key_path = './run/train/key.pt'
device = 'cuda'
# 从权重文件中加载模型
model = Alexnet(3, 10, 32).to(device)
model.load_state_dict(torch.load(model_path))
# 获取模型中待嵌入的卷积层
conv_list = []
for module in model.modules():
if isinstance(module, nn.Conv2d):
conv_list.append(module)
conv_list = conv_list[0:2]
# 初始化白盒水印解码器
decoder = ModelDecoder(layers=conv_list, key_path=key_path, device=device) # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
secret_extract = decoder.decode() # 提取密码标签
python setup.py sdist
项目目录会生成dist
目录,其中watermark_codec-1.0.tar.gz
即为发布包
pip install watermark_codec-1.0.tar.gz