import mindspore
import mindspore.nn as nn
from mindspore.dataset import vision, Cifar10Dataset
from mindspore.dataset.vision import transforms
import tqdm

from tests.model.AlexNet import AlexNet
from tests.secret_func import get_secret
from watermark_codec import ModelEncoder

mindspore.set_context(device_target="CPU", max_device_memory="4GB")

train_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='train', shuffle=True)
test_dataset = Cifar10Dataset(dataset_dir='data/cifar-10-batches-bin', usage='test')
batch_size = 32


def datapipe(dataset, batch_size):
    """Attach the image/label transforms to ``dataset`` and batch it.

    The 'image' column is rescaled by 1/255, normalized, and converted from
    HWC to CHW layout; the 'label' column is cast to int32. The mapped
    dataset is then grouped into batches of ``batch_size``.

    Args:
        dataset: Dataset exposing 'image' and 'label' columns.
        batch_size: Number of samples per batch.

    Returns:
        The mapped and batched dataset.
    """
    # NOTE(review): a single mean/std value is supplied while the model below
    # is built with input_channels=3 -- confirm these statistics are the ones
    # intended for this dataset.
    image_ops = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW(),
    ]
    label_op = transforms.TypeCast(mindspore.int32)

    return (
        dataset
        .map(image_ops, 'image')
        .map(label_op, 'label')
        .batch(batch_size)
    )


# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, batch_size)
test_dataset = datapipe(test_dataset, batch_size)

# Define model
model = AlexNet(input_channels=3, output_num=10, input_size=32)
print(model)

# init white_box watermark encoder
secret = get_secret(512)
key_path = './run/train/key.ckpt'
# Gather every Conv2d cell of the model; only the first two (in
# cells_and_names() order) are handed to the encoder.
layers = [cell for _, cell in model.cells_and_names() if isinstance(cell, nn.Conv2d)]
model_encoder = ModelEncoder(layers=layers[0:2], secret=secret, key_path=key_path)

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), 1e-2)


# 1. Define forward function
def forward_fn(data, label):
    """Compute the combined training loss for one batch.

    Args:
        data: Input batch passed to the module-level ``model``.
        label: Target labels passed to ``loss_fn``.

    Returns:
        A tuple ``(loss, loss_embed, logits)`` where ``loss`` is the
        cross-entropy loss plus the watermark embedding loss reported by
        ``model_encoder``, ``loss_embed`` is that embedding loss alone, and
        ``logits`` is the raw model output.
    """
    predictions = model(data)
    wm_loss = model_encoder.get_embeder_loss()
    total_loss = loss_fn(predictions, label) + wm_loss
    return total_loss, wm_loss, predictions


# 2. Get gradient function
# With has_aux=True only the first output of forward_fn (the combined loss)
# is differentiated; the remaining outputs are passed through unchanged.
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# 3. Define function of one-step training
def train_step(data, label):
    """Run a single optimisation step on one batch.

    Args:
        data: Input batch forwarded to ``grad_fn``.
        label: Target labels forwarded to ``grad_fn``.

    Returns:
        A tuple ``(loss, loss_embed)`` taken from the outputs of ``grad_fn``.
    """
    (loss, loss_embed, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss, loss_embed


def train(model, dataset):
    """Train for one pass over ``dataset`` while showing a progress bar.

    Args:
        model: Network switched to training mode before iterating.
        dataset: Batched dataset yielding ``(data, label)`` tuples.
    """
    num_batches = dataset.get_dataset_size()
    # The context manager closes the bar even if a training step raises.
    with tqdm.tqdm(total=num_batches, desc='train', unit='batch') as progress:
        model.set_train()
        for data, label in dataset.create_tuple_iterator():
            loss, loss_embed = train_step(data, label)
            # Show the current losses next to the bar.
            progress.set_postfix({'train_loss': loss.asnumpy(), 'wm_loss': loss_embed.asnumpy()})
            # Advance the progress bar by one batch.
            progress.update(1)


def test(model, dataset, loss_fn):
    """Evaluate ``model`` on ``dataset`` and print accuracy and average loss.

    The per-batch loss is ``loss_fn`` on the predictions plus the watermark
    embedding loss reported by the module-level ``model_encoder``.

    Args:
        model: Network switched to evaluation mode before iterating.
        dataset: Batched dataset yielding ``(data, label)`` tuples.
        loss_fn: Callable computing the classification loss from
            ``(pred, label)``.
    """
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    # The context manager closes the bar even if an evaluation step raises.
    with tqdm.tqdm(total=num_batches, desc='test', unit='batch') as progress:
        for data, label in dataset.create_tuple_iterator():
            pred = model(data)
            total += len(data)
            loss_embed = model_encoder.get_embeder_loss().asnumpy()
            loss = loss_fn(pred, label).asnumpy() + loss_embed
            # Show the current losses next to the bar.
            progress.set_postfix({'val_loss': loss, 'wm_loss': loss_embed})
            # Advance the progress bar by one batch.
            progress.update(1)
            test_loss += loss
            correct += (pred.argmax(1) == label).asnumpy().sum()
        # Guard the averages so an empty dataset reports zeros instead of
        # raising ZeroDivisionError.
        test_loss = test_loss / num_batches if num_batches else 0.0
        correct = correct / total if total else 0.0
        progress.write(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def save(model, save_path):
    """Write a checkpoint of ``model`` to ``save_path``."""
    mindspore.save_checkpoint(model, save_path)


def export_onnx(model):
    """Export ``model`` to './run/train/AlexNet.onnx' in ONNX format.

    The module-level ``train_dataset`` is passed as the export input.
    """
    # NOTE(review): confirm that the installed MindSpore version accepts a
    # Dataset (rather than a Tensor) as export input for the ONNX format.
    mindspore.export(model, train_dataset, file_name='./run/train/AlexNet.onnx', file_format='ONNX')


if __name__ == '__main__':
    epochs = 50
    save_path = "./run/train/AlexNet.ckpt"
    # NOTE(review): the original indentation of this driver was lost; the
    # checkpoint is written after every epoch and the ONNX export runs once
    # after training -- confirm this matches the intended nesting.
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(model, train_dataset)
        test(model, test_dataset, loss_fn)
        save(model, save_path)
    export_onnx(model)
    print("Done!")