@@ -5,6 +5,8 @@ from mindspore.dataset.vision import transforms
import tqdm

from tests.model.AlexNet import AlexNet
+from tests.secret_func import get_secret
+from watermark_codec import ModelEncoder
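+# ModelEncoder (from watermark_codec) embeds a white-box watermark into the layers passed to it and exposes the embedding loss used below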

train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
@@ -32,6 +34,15 @@ test_dataset = datapipe(test_dataset, batch_size)
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

+# Initialize the white-box watermark encoder
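+# Collect the model's Conv2d layers; the first two will carry the watermark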
+layers = []
+secret = get_secret(512)
+key_path = './run/train/key.ckpt'
+for name, layer in model.cells_and_names():
+    if isinstance(layer, nn.Conv2d):
+        layers.append(layer)
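+# Embed the secret into the first two Conv2d layers (key_path presumably stores the key needed for later extraction)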
+model_encoder = ModelEncoder(layers=layers[0:2], secret=secret, key_path=key_path)
+
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), 1e-2)
@@ -40,8 +51,9 @@ optimizer = nn.Adam(model.trainable_params(), 1e-2)
# 1. Define forward function
def forward_fn(data, label):
    logits = model(data)
-    loss = loss_fn(logits, label)
-    return loss, logits
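+    # Watermark embedding loss from the encoder; adding it to the task loss embeds the secret while the model trains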
+    loss_embed = model_encoder.get_embeder_loss()
+    loss = loss_fn(logits, label) + loss_embed
+    return loss, loss_embed, logits


# 2. Get gradient function
@@ -50,9 +62,9 @@ grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_a

# 3. Define function of one-step training
def train_step(data, label):
-    (loss, _), grads = grad_fn(data, label)
+    (loss, loss_embed, _), grads = grad_fn(data, label)
    optimizer(grads)
-    return loss
+    return loss, loss_embed


def train(model, dataset):
@@ -60,8 +72,8 @@ def train(model, dataset):
    tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
-        loss = train_step(data, label)
-        tqdm_show.set_postfix({'train_loss': loss.asnumpy()})  # show in the progress bar
+        loss, loss_embed = train_step(data, label)
+        tqdm_show.set_postfix({'train_loss': loss.asnumpy(), 'wm_loss': loss_embed.asnumpy()})  # show in the progress bar
        tqdm_show.update(1)  # update the progress bar
    tqdm_show.close()

@@ -74,8 +86,9 @@ def test(model, dataset, loss_fn):
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
-        loss = loss_fn(pred, label).asnumpy()
-        tqdm_show.set_postfix({'val_loss': loss})  # show in the progress bar
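+        # During evaluation the watermark loss is only reported; no weights are updated here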
+        loss_embed = model_encoder.get_embeder_loss().asnumpy()
+        loss = loss_fn(pred, label).asnumpy() + loss_embed
+        tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed})  # show in the progress bar
        tqdm_show.update(1)  # update the progress bar
        test_loss += loss
        correct += (pred.argmax(1) == label).asnumpy().sum()