# train.py
# Trains a host CNN (AlexNet by default) on CIFAR-10 while jointly embedding a
# random binary watermark key into its first-layer weights via a separate
# watermark-embedder network.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm  # progress bar

from models.alexnet import AlexNet
from models.embeder import WatermarkEmbeder
from models.googlenet import GoogLeNet
from models.lenet import LeNet
from models.vgg16 import VGGNet
  15. # 参数
  16. batch_size = 500
  17. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  18. num_epochs = 40
  19. wm_length = 1024
  20. # 设置随机数种子
  21. np.random.seed(1)
  22. lambda1 = 0.05
  23. b = np.random.randint(low=0, high=2, size=(1, wm_length)) # 生成模拟随机密钥
  24. np.save('b.npy', b)
  25. b = nn.Parameter(torch.tensor(b, dtype=torch.float32).to(device), requires_grad=False)
  26. b.requires_grad = False
  27. # 数据预处理和加载
  28. transform_train = transforms.Compose([
  29. transforms.RandomCrop(32, padding=4),
  30. transforms.RandomHorizontalFlip(),
  31. transforms.ToTensor(),
  32. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  33. ])
  34. transform_test = transforms.Compose([
  35. transforms.ToTensor(),
  36. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  37. ])
  38. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
  39. trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
  40. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
  41. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
  42. # 创建VGGNet模型实例
  43. # model = VGGNet(num_classes=10).to(device)
  44. # model = GoogLeNet(num_classes=10).to(device)
  45. # model = LeNet().to(device)
  46. model = AlexNet().to(device)
  47. # 打印模型结构
  48. print(model)
  49. # 断点续训
  50. # model.load_state_dict(torch.load("./result/host_dnn.pt"))
  51. # 获取指定层权重信息
  52. # vgg16
  53. # weight = model.features[0].weight.view(1, -1)
  54. # googleNet
  55. # weight = model.conv1[0].weight.view(1, -1)
  56. # leNet
  57. # weight = model.conv1.weight.view(1, -1)
  58. # AlexNet
  59. weight = model.layer1[0].weight.view(1, -1)
  60. # 获取权重向量长度
  61. pct_dim = np.prod(weight.shape[1:4])
  62. # 创建水印嵌入模型实例
  63. model_embeder = WatermarkEmbeder(pct_dim, wm_length).to(device)
  64. # model_embeder.load_state_dict(torch.load("./result/embeder.pt"))
  65. # 定义目标模型损失函数和优化器
  66. criterion = nn.CrossEntropyLoss()
  67. # 目标模型使用Adam优化器
  68. optimizer = optim.Adam(model.parameters(), lr=1e-4) # 调整学习率
  69. optimizer_embeder = optim.Adam(model_embeder.parameters(), lr=0.5)
  70. # 初始化空列表以存储准确度和损失
  71. train_accs = []
  72. train_losses = []
  73. torch.autograd.set_detect_anomaly(True)
  74. for epoch in range(num_epochs):
  75. model_embeder.train()
  76. model.train()
  77. running_loss = 0.0
  78. correct = 0
  79. total = 0
  80. # 使用tqdm创建进度条
  81. with (tqdm(total=len(trainloader), desc=f"Epoch {epoch + 1}", unit="batch") as pbar):
  82. for i, data in enumerate(trainloader, 0):
  83. inputs, labels = data
  84. inputs, labels = inputs.to(device), labels.to(device)
  85. optimizer_embeder.zero_grad()
  86. outputs_embeder = model_embeder(weight)
  87. loss_embeder = F.binary_cross_entropy(outputs_embeder, b)
  88. # loss_embeder.backward(retain_graph=True)
  89. # optimizer_embeder.step()
  90. optimizer.zero_grad()
  91. outputs = model(inputs)
  92. loss_c = criterion(outputs, labels)
  93. loss_h = loss_c + lambda1 * loss_embeder
  94. loss_h.backward()
  95. optimizer.step()
  96. optimizer_embeder.step()
  97. running_loss += loss_h.item()
  98. _, predicted = torch.max(outputs.data, 1)
  99. total += labels.size(0)
  100. correct += (predicted == labels).sum().item()
  101. # 更新进度条
  102. pbar.set_postfix(loss=running_loss / (i + 1), loss_em=loss_embeder.item(), acc=100 * correct / total)
  103. pbar.update()
  104. # 计算准确度和损失
  105. epoch_acc = 100 * correct / total
  106. epoch_loss = running_loss / len(trainloader)
  107. # 记录准确度和损失值
  108. train_accs.append(epoch_acc)
  109. train_losses.append(epoch_loss)
  110. print(f"Epoch {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}%")
  111. torch.save(model.state_dict(), "./result/host_dnn.pt")
  112. torch.save(model_embeder.state_dict(), "./result/embeder.pt")
  113. # 导出onnx格式
  114. with torch.no_grad():
  115. x = torch.randn(1, 3, 32, 32).to(device)
  116. torch.onnx.export(model, x, 'host_dnn.onnx', opset_version=11, input_names=['input'],
  117. output_names=['output'])
  118. torch.onnx.export(model_embeder, weight, 'embeder.onnx', opset_version=11, input_names=['input'],
  119. output_names=['output'])
  120. # 更新学习率
  121. # scheduler.step()
  122. # 测试模型
  123. if epoch % 5 == 4:
  124. model.eval()
  125. correct = 0
  126. total = 0
  127. with torch.no_grad():
  128. for data in testloader:
  129. inputs, labels = data
  130. inputs, labels = inputs.to(device), labels.to(device)
  131. outputs = model(inputs)
  132. _, predicted = torch.max(outputs.data, 1)
  133. total += labels.size(0)
  134. correct += (predicted == labels).sum().item()
  135. print(f"Accuracy on test set: {(100 * correct / total):.2f}%")
  136. # scheduler.step()
  137. print("Finished Training")
  138. # 绘制准确度和损失曲线
  139. plt.figure(figsize=(12, 4))
  140. plt.subplot(1, 2, 1)
  141. plt.plot(train_accs)
  142. plt.title('Training Accuracy')
  143. plt.xlabel('Epoch')
  144. plt.ylabel('Accuracy (%)')
  145. plt.subplot(1, 2, 2)
  146. plt.plot(train_losses)
  147. plt.title('Training Loss')
  148. plt.xlabel('Epoch')
  149. plt.ylabel('Loss')
  150. plt.tight_layout()
  151. plt.show()
  152. print("Finished drawing")
  153. # 测试水印嵌入
  154. model_embeder.eval()
  155. outputs_embeder = model_embeder(weight)
  156. outputs_embeder = (outputs_embeder > 0.5).int()
  157. wrr_bits = (outputs_embeder != b).sum().item()
  158. print(f'wrr_bits={wrr_bits}')