|
@@ -8,6 +8,7 @@ from tests.model.AlexNet import AlexNet
|
|
from tests.secret_func import get_secret
|
|
from tests.secret_func import get_secret
|
|
from watermark_codec import ModelEncoder
|
|
from watermark_codec import ModelEncoder
|
|
|
|
|
|
|
|
+mindspore.set_context(device_target="CPU", max_device_memory="4GB")
|
|
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
|
|
train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
|
|
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
|
|
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
|
|
batch_size = 32
|
|
batch_size = 32
|
|
@@ -88,7 +89,7 @@ def test(model, dataset, loss_fn):
|
|
total += len(data)
|
|
total += len(data)
|
|
loss_embed = model_encoder.get_embeder_loss().asnumpy()
|
|
loss_embed = model_encoder.get_embeder_loss().asnumpy()
|
|
loss = loss_fn(pred, label).asnumpy() + loss_embed
|
|
loss = loss_fn(pred, label).asnumpy() + loss_embed
|
|
- tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed.asnumpy()}) # 添加显示
|
|
|
|
|
|
+ tqdm_show.set_postfix({'val_loss': loss, 'wm_loss': loss_embed}) # 添加显示
|
|
tqdm_show.update(1) # 更新进度条
|
|
tqdm_show.update(1) # 更新进度条
|
|
test_loss += loss
|
|
test_loss += loss
|
|
correct += (pred.argmax(1) == label).asnumpy().sum()
|
|
correct += (pred.argmax(1) == label).asnumpy().sum()
|
|
@@ -97,12 +98,16 @@ def test(model, dataset, loss_fn):
|
|
tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
|
tqdm_show.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
|
|
tqdm_show.close()
|
|
tqdm_show.close()
|
|
|
|
|
|
|
|
+def save(model, save_path):
|
|
|
|
+ mindspore.save_checkpoint(model, save_path)
|
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- epochs = 10
|
|
|
|
- mindspore.set_context(device_target="GPU")
|
|
|
|
|
|
+ epochs = 50
|
|
|
|
+ save_path = "./run/train/AlexNet.ckpt"
|
|
for t in range(epochs):
|
|
for t in range(epochs):
|
|
print(f"Epoch {t + 1}\n-------------------------------")
|
|
print(f"Epoch {t + 1}\n-------------------------------")
|
|
train(model, train_dataset)
|
|
train(model, train_dataset)
|
|
test(model, test_dataset, loss_fn)
|
|
test(model, test_dataset, loss_fn)
|
|
|
|
+ save(model, save_path)
|
|
print("Done!")
|
|
print("Done!")
|