|
@@ -0,0 +1,191 @@
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+import torchvision
|
|
|
+import torchvision.transforms as transforms
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+from torch import optim
|
|
|
+from tqdm import tqdm # 导入tqdm
|
|
|
+
|
|
|
+from models.alexnet import AlexNet
|
|
|
+from models.embeder import WatermarkEmbeder
|
|
|
+from models.googlenet import GoogLeNet
|
|
|
+from models.lenet import LeNet
|
|
|
+from models.vgg16 import VGGNet
|
|
|
+
|
|
|
+# 参数
|
|
|
+batch_size = 500
|
|
|
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+num_epochs = 40
|
|
|
+wm_length = 1024
|
|
|
+
|
|
|
+# 设置随机数种子
|
|
|
+np.random.seed(1)
|
|
|
+lambda1 = 0.05
|
|
|
+b = np.random.randint(low=0, high=2, size=(1, wm_length)) # 生成模拟随机密钥
|
|
|
+np.save('b.npy', b)
|
|
|
+b = nn.Parameter(torch.tensor(b, dtype=torch.float32).to(device), requires_grad=False)
|
|
|
+b.requires_grad = False
|
|
|
+
|
|
|
+# 数据预处理和加载
|
|
|
+transform_train = transforms.Compose([
|
|
|
+ transforms.RandomCrop(32, padding=4),
|
|
|
+ transforms.RandomHorizontalFlip(),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
|
|
+])
|
|
|
+
|
|
|
+transform_test = transforms.Compose([
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
|
|
+])
|
|
|
+
|
|
|
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
|
|
|
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
|
|
|
+
|
|
|
+testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
|
|
|
+testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
|
|
|
+
|
|
|
+# 创建VGGNet模型实例
|
|
|
+# model = VGGNet(num_classes=10).to(device)
|
|
|
+# model = GoogLeNet(num_classes=10).to(device)
|
|
|
+# model = LeNet().to(device)
|
|
|
+model = AlexNet().to(device)
|
|
|
+
|
|
|
+# 打印模型结构
|
|
|
+print(model)
|
|
|
+
|
|
|
+# 断点续训
|
|
|
+# model.load_state_dict(torch.load("./result/host_dnn.pt"))
|
|
|
+
|
|
|
+# 获取指定层权重信息
|
|
|
+# vgg16
|
|
|
+# weight = model.features[0].weight.view(1, -1)
|
|
|
+# googleNet
|
|
|
+# weight = model.conv1[0].weight.view(1, -1)
|
|
|
+# leNet
|
|
|
+# weight = model.conv1.weight.view(1, -1)
|
|
|
+# AlexNet
|
|
|
+weight = model.layer1[0].weight.view(1, -1)
|
|
|
+
|
|
|
+# 获取权重向量长度
|
|
|
+pct_dim = np.prod(weight.shape[1:4])
|
|
|
+# 创建水印嵌入模型实例
|
|
|
+model_embeder = WatermarkEmbeder(pct_dim, wm_length).to(device)
|
|
|
+# model_embeder.load_state_dict(torch.load("./result/embeder.pt"))
|
|
|
+
|
|
|
+# 定义目标模型损失函数和优化器
|
|
|
+criterion = nn.CrossEntropyLoss()
|
|
|
+
|
|
|
+# 目标模型使用Adam优化器
|
|
|
+optimizer = optim.Adam(model.parameters(), lr=1e-4) # 调整学习率
|
|
|
+optimizer_embeder = optim.Adam(model_embeder.parameters(), lr=0.5)
|
|
|
+
|
|
|
+# 初始化空列表以存储准确度和损失
|
|
|
+train_accs = []
|
|
|
+train_losses = []
|
|
|
+torch.autograd.set_detect_anomaly(True)
|
|
|
+
|
|
|
+for epoch in range(num_epochs):
|
|
|
+ model_embeder.train()
|
|
|
+ model.train()
|
|
|
+ running_loss = 0.0
|
|
|
+ correct = 0
|
|
|
+ total = 0
|
|
|
+
|
|
|
+ # 使用tqdm创建进度条
|
|
|
+ with (tqdm(total=len(trainloader), desc=f"Epoch {epoch + 1}", unit="batch") as pbar):
|
|
|
+ for i, data in enumerate(trainloader, 0):
|
|
|
+ inputs, labels = data
|
|
|
+ inputs, labels = inputs.to(device), labels.to(device)
|
|
|
+
|
|
|
+ optimizer_embeder.zero_grad()
|
|
|
+ outputs_embeder = model_embeder(weight)
|
|
|
+ loss_embeder = F.binary_cross_entropy(outputs_embeder, b)
|
|
|
+ # loss_embeder.backward(retain_graph=True)
|
|
|
+ # optimizer_embeder.step()
|
|
|
+
|
|
|
+ optimizer.zero_grad()
|
|
|
+ outputs = model(inputs)
|
|
|
+ loss_c = criterion(outputs, labels)
|
|
|
+ loss_h = loss_c + lambda1 * loss_embeder
|
|
|
+ loss_h.backward()
|
|
|
+ optimizer.step()
|
|
|
+ optimizer_embeder.step()
|
|
|
+
|
|
|
+ running_loss += loss_h.item()
|
|
|
+
|
|
|
+ _, predicted = torch.max(outputs.data, 1)
|
|
|
+ total += labels.size(0)
|
|
|
+ correct += (predicted == labels).sum().item()
|
|
|
+
|
|
|
+ # 更新进度条
|
|
|
+ pbar.set_postfix(loss=running_loss / (i + 1), loss_em=loss_embeder.item(), acc=100 * correct / total)
|
|
|
+ pbar.update()
|
|
|
+
|
|
|
+ # 计算准确度和损失
|
|
|
+ epoch_acc = 100 * correct / total
|
|
|
+ epoch_loss = running_loss / len(trainloader)
|
|
|
+
|
|
|
+ # 记录准确度和损失值
|
|
|
+ train_accs.append(epoch_acc)
|
|
|
+ train_losses.append(epoch_loss)
|
|
|
+ print(f"Epoch {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}%")
|
|
|
+
|
|
|
+ torch.save(model.state_dict(), "./result/host_dnn.pt")
|
|
|
+ torch.save(model_embeder.state_dict(), "./result/embeder.pt")
|
|
|
+
|
|
|
+ # 导出onnx格式
|
|
|
+ with torch.no_grad():
|
|
|
+ x = torch.randn(1, 3, 32, 32).to(device)
|
|
|
+ torch.onnx.export(model, x, 'host_dnn.onnx', opset_version=11, input_names=['input'],
|
|
|
+ output_names=['output'])
|
|
|
+ torch.onnx.export(model_embeder, weight, 'embeder.onnx', opset_version=11, input_names=['input'],
|
|
|
+ output_names=['output'])
|
|
|
+ # 更新学习率
|
|
|
+ # scheduler.step()
|
|
|
+
|
|
|
+ # 测试模型
|
|
|
+ if epoch % 5 == 4:
|
|
|
+ model.eval()
|
|
|
+ correct = 0
|
|
|
+ total = 0
|
|
|
+ with torch.no_grad():
|
|
|
+ for data in testloader:
|
|
|
+ inputs, labels = data
|
|
|
+ inputs, labels = inputs.to(device), labels.to(device)
|
|
|
+ outputs = model(inputs)
|
|
|
+ _, predicted = torch.max(outputs.data, 1)
|
|
|
+ total += labels.size(0)
|
|
|
+ correct += (predicted == labels).sum().item()
|
|
|
+
|
|
|
+ print(f"Accuracy on test set: {(100 * correct / total):.2f}%")
|
|
|
+ # scheduler.step()
|
|
|
+
|
|
|
+print("Finished Training")
|
|
|
+
|
|
|
+# 绘制准确度和损失曲线
|
|
|
+plt.figure(figsize=(12, 4))
|
|
|
+plt.subplot(1, 2, 1)
|
|
|
+plt.plot(train_accs)
|
|
|
+plt.title('Training Accuracy')
|
|
|
+plt.xlabel('Epoch')
|
|
|
+plt.ylabel('Accuracy (%)')
|
|
|
+
|
|
|
+plt.subplot(1, 2, 2)
|
|
|
+plt.plot(train_losses)
|
|
|
+plt.title('Training Loss')
|
|
|
+plt.xlabel('Epoch')
|
|
|
+plt.ylabel('Loss')
|
|
|
+
|
|
|
+plt.tight_layout()
|
|
|
+plt.show()
|
|
|
+
|
|
|
+print("Finished drawing")
|
|
|
+# 测试水印嵌入
|
|
|
+model_embeder.eval()
|
|
|
+outputs_embeder = model_embeder(weight)
|
|
|
+outputs_embeder = (outputs_embeder > 0.5).int()
|
|
|
+wrr_bits = (outputs_embeder != b).sum().item()
|
|
|
+print(f'wrr_bits={wrr_bits}')
|