Selaa lähdekoodia

修改训练流程,添加水印嵌入流程

liyan 11 kuukautta sitten
vanhempi
commit
37215e3393
1 muutettua tiedostoa jossa 21 lisäystä ja 8 poistoa
  1. 21 8
      tests/train.py

+ 21 - 8
tests/train.py

@@ -5,6 +5,8 @@ from mindspore.dataset.vision import transforms
 import tqdm
 
 from tests.model.AlexNet import AlexNet
+from tests.secret_func import get_secret
+from watermark_codec import ModelEncoder
 
 train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
 test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
@@ -32,6 +34,15 @@ test_dataset = datapipe(test_dataset, batch_size)
 model = AlexNet(input_channels=3, output_num=10, input_size=32)
 print(model)
 
+# init white_box watermark encoder
+layers = []
+secret = get_secret(512)
+key_path = './run/train/key.ckpt'
+for name, layer in model.cells_and_names():
+    if isinstance (layer, nn.Conv2d):
+        layers.append(layer)
+model_encoder = ModelEncoder(layers=layers[0:2], secret=secret, key_path=key_path)
+
 # Instantiate loss function and optimizer
 loss_fn = nn.CrossEntropyLoss()
 optimizer = nn.Adam(model.trainable_params(), 1e-2)
@@ -40,8 +51,9 @@ optimizer = nn.Adam(model.trainable_params(), 1e-2)
 # 1. Define forward function
 def forward_fn(data, label):
     logits = model(data)
-    loss = loss_fn(logits, label)
-    return loss, logits
+    loss_embed = model_encoder.get_embeder_loss()
+    loss = loss_fn(logits, label) + loss_embed
+    return loss, loss_embed, logits
 
 
 # 2. Get gradient function
@@ -50,9 +62,9 @@ grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_a
 
 # 3. Define function of one-step training
 def train_step(data, label):
-    (loss, _), grads = grad_fn(data, label)
+    (loss, loss_embed, _), grads = grad_fn(data, label)
     optimizer(grads)
-    return loss
+    return loss, loss_embed
 
 
 def train(model, dataset):
@@ -60,8 +72,8 @@ def train(model, dataset):
     tqdm_show = tqdm.tqdm(total=num_batches, desc='train', unit='batch')
     model.set_train()
     for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
-        loss = train_step(data, label)
-        tqdm_show.set_postfix({'train_loss': loss.asnumpy()})  # 添加显示
+        loss, loss_embed = train_step(data, label)
+        tqdm_show.set_postfix({'train_loss': loss.asnumpy(), 'wm_loss': loss_embed.asnumpy()})  # 添加显示
         tqdm_show.update(1)  # 更新进度条
     tqdm_show.close()
 
@@ -74,8 +86,9 @@ def test(model, dataset, loss_fn):
     for data, label in dataset.create_tuple_iterator():
         pred = model(data)
         total += len(data)
-        loss = loss_fn(pred, label).asnumpy()
-        tqdm_show.set_postfix({'val_loss': loss})  # 添加显示
+        loss_embed = model_encoder.get_embeder_loss().asnumpy()
+        loss = loss_fn(pred, label).asnumpy() + loss_embed
+        tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed.asnumpy()})  # 添加显示
         tqdm_show.update(1)  # 更新进度条
         test_loss += loss
         correct += (pred.argmax(1) == label).asnumpy().sum()