val.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. """
  2. 测试白盒水印标签
  3. """
  4. import torch
  5. import torchvision
  6. from torch import nn
  7. import torchvision.transforms as transforms
  8. from model.Alexnet import Alexnet
  9. from watermark_codec import ModelDecoder
  10. from watermark_codec.tool import secret_func
  # Artifacts written by the training run: the model weights and the
  # watermark key consumed by the decoder.
  model_path = './run/train/alex_net.pt'
  key_path = './run/train/key.pt'
  # Prefer the GPU but fall back to CPU so the script does not crash on the
  # first `.to(device)` call when CUDA is unavailable.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  # Evaluation pipeline for the CIFAR-10 test split: convert images to
  # tensors, then normalize each channel.
  # NOTE(review): these mean/std values appear to be the common ImageNet
  # statistics rather than CIFAR-10's; presumably they must match whatever
  # the training script used — confirm before changing them.
  transform_test = transforms.Compose(
      [
          transforms.ToTensor(),
          transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
      ]
  )
  testset = torchvision.datasets.CIFAR10(
      root='./data',
      train=False,
      transform=transform_test,
      download=True,
  )
  # Batches of 100, iterated in dataset order (no shuffling).
  testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=100, shuffle=False)
  # Rebuild the network and load the trained weights whose embedded
  # watermark is checked below.
  # NOTE(review): Alexnet(3, 10, 32) presumably means 3 input channels,
  # 10 classes and 32x32 inputs — confirm against model/Alexnet.py.
  model = Alexnet(3, 10, 32).to(device)
  # map_location remaps tensors that were saved on another device (e.g. a
  # GPU) onto the device chosen here, so the checkpoint also loads on CPU.
  model.load_state_dict(torch.load(model_path, map_location=device))
  # Select the convolution layers that carry the watermark: every nn.Conv2d
  # in model.modules() traversal order, keeping only the first two.
  # NOTE(review): presumably this selection must match the layers used when
  # the watermark was embedded during training — confirm.
  conv_list = [m for m in model.modules() if isinstance(m, nn.Conv2d)][:2]
  # White-box watermark check: build a decoder over the selected conv layers
  # using the key saved at training time, then extract the secret label.
  decoder = ModelDecoder(layers=conv_list, key_path=key_path, device=device)
  secret_extract = decoder.decode()
  print(f"secret_extract: {secret_extract}")
  # Report the result of secret_func.verify_secret on the extracted label
  # ("secret label verified successfully" vs. "verification failed").
  print('密码标签验证成功' if secret_func.verify_secret(secret_extract) else '验证失败')
  # Measure top-1 classification accuracy of the loaded model on the test set.
  model.eval()  # put modules such as dropout/batch-norm (if any) in eval mode
  correct = 0
  total = 0
  with torch.no_grad():  # inference only; no autograd graph needed
      for inputs, labels in testloader:
          inputs, labels = inputs.to(device), labels.to(device)
          # Predicted class = index of the highest score along dim 1; inside
          # no_grad() there is no need to go through `.data`.
          predicted = model(inputs).argmax(dim=1)
          total += labels.size(0)
          correct += (predicted == labels).sum().item()
  print(f"Accuracy on test set: {(100 * correct / total):.2f}%")