# train.py
  1. """
  2. 示例代码,演示白盒水印编码器与解码器的使用
  3. """
  4. import os
  5. import torch
  6. import torch.nn as nn
  7. import torchvision
  8. import torchvision.transforms as transforms
  9. from matplotlib import pyplot as plt
  10. from torch import optim
  11. from tqdm import tqdm # 导入tqdm
  12. import secret_func
  13. from model.Alexnet import Alexnet
  14. from watermark_codec import ModelEncoder
  15. # 参数
  16. batch_size = 500
  17. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  18. num_epochs = 40
  19. wm_length = 1024
  20. num_workers = 2
  21. # 设置随机数种子
  22. # np.random.seed(1)
  23. # lambda1 = 0.05
  24. # b = np.random.randint(low=0, high=2, size=(1, wm_length)) # 生成模拟随机密钥
  25. # np.save('b.npy', b)
  26. # b = nn.Parameter(torch.tensor(b, dtype=torch.float32).to(device), requires_grad=False)
  27. # b.requires_grad = False
  28. # 存储路径
  29. model_path = './run/train/alex_net.pt'
  30. key_path = './run/train/key.pt'
  31. os.makedirs(os.path.dirname(model_path), exist_ok=True)
  32. os.makedirs(os.path.dirname(key_path), exist_ok=True)
  33. # 数据预处理和加载
  34. transform_train = transforms.Compose([
  35. transforms.RandomCrop(32, padding=4),
  36. transforms.RandomHorizontalFlip(),
  37. transforms.ToTensor(),
  38. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  39. ])
  40. transform_test = transforms.Compose([
  41. transforms.ToTensor(),
  42. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  43. ])
  44. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
  45. trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  46. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
  47. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=num_workers)
  48. # 创建AlexNet模型实例
  49. model = Alexnet(3, 10, 32).to(device)
  50. print(model)
  51. # 获取模型中待嵌入的卷积层
  52. conv_list = []
  53. for module in model.modules():
  54. if isinstance(module, nn.Conv2d):
  55. conv_list.append(module)
  56. conv_list = conv_list[0:2]
  57. # 创建模型水印编码器
  58. secret = secret_func.get_secret(512)
  59. encoder = ModelEncoder(layers=conv_list, secret=secret, key_path=key_path, device='cuda')
  60. # 定义目标模型损失函数和优化器
  61. criterion = nn.CrossEntropyLoss()
  62. # 目标模型使用Adam优化器
  63. optimizer = optim.Adam(model.parameters(), lr=1e-4) # 调整学习率
  64. # 初始化空列表以存储准确度和损失
  65. train_accs = []
  66. train_losses = []
  67. torch.autograd.set_detect_anomaly(True)
  68. for epoch in range(num_epochs):
  69. model.train()
  70. running_loss = 0.0
  71. correct = 0
  72. total = 0
  73. # 使用tqdm创建进度条
  74. with tqdm(total=len(trainloader), desc=f"Epoch {epoch + 1}", unit="batch") as pbar:
  75. for i, data in enumerate(trainloader, 0):
  76. inputs, labels = data
  77. inputs, labels = inputs.to(device), labels.to(device)
  78. optimizer.zero_grad()
  79. outputs = model(inputs)
  80. loss = criterion(outputs, labels)
  81. # loss = encoder.get_loss(loss) # 实际应用只调用get_loss修改原损失即可
  82. # 测试时可以获取白盒水印的损失并打印 ------------------------------
  83. loss_embeder = encoder.get_embeder_loss()
  84. loss += loss_embeder
  85. # -----------------------------------------------------------
  86. loss.backward()
  87. optimizer.step()
  88. running_loss += loss.item()
  89. _, predicted = torch.max(outputs.data, 1)
  90. total += labels.size(0)
  91. correct += (predicted == labels).sum().item()
  92. # 更新进度条
  93. pbar.set_postfix(loss=running_loss / (i + 1), loss_embeder=loss_embeder.item(), acc=100 * correct / total)
  94. pbar.update()
  95. # 计算准确度和损失
  96. epoch_acc = 100 * correct / total
  97. epoch_loss = running_loss / len(trainloader)
  98. # 记录准确度和损失值
  99. train_accs.append(epoch_acc)
  100. train_losses.append(epoch_loss)
  101. print(f"Epoch {epoch + 1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}%")
  102. torch.save(model.state_dict(), model_path)
  103. # 测试模型
  104. if epoch % 5 == 4:
  105. model.eval()
  106. correct = 0
  107. total = 0
  108. with torch.no_grad():
  109. for data in testloader:
  110. inputs, labels = data
  111. inputs, labels = inputs.to(device), labels.to(device)
  112. outputs = model(inputs)
  113. _, predicted = torch.max(outputs.data, 1)
  114. total += labels.size(0)
  115. correct += (predicted == labels).sum().item()
  116. print(f"Accuracy on test set: {(100 * correct / total):.2f}%")
  117. print("Finished Training")
  118. # 绘制准确度和损失曲线
  119. plt.figure(figsize=(12, 4))
  120. plt.subplot(1, 2, 1)
  121. plt.plot(train_accs)
  122. plt.title('Training Accuracy')
  123. plt.xlabel('Epoch')
  124. plt.ylabel('Accuracy (%)')
  125. plt.subplot(1, 2, 2)
  126. plt.plot(train_losses)
  127. plt.title('Training Loss')
  128. plt.xlabel('Epoch')
  129. plt.ylabel('Loss')
  130. plt.tight_layout()
  131. plt.show()
  132. print("Finished drawing")