predict_pt_embed.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import os
  2. import time
  3. import numpy as np
  4. import torch
  5. import argparse
  6. from PIL import Image
  7. from torch import nn
  8. from torchvision import transforms
  9. from watermark_codec import ModelDecoder
  10. from block import secret_get
  11. # -------------------------------------------------------------------------------------------------------------------- #
  12. parser = argparse.ArgumentParser(description='|pt模型推理|')
  13. parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
  14. parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
  15. parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
  16. parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
  17. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  18. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  19. parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
  20. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  21. # -------------------------------------------------------------------------------------------------------------------- #
  22. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  23. assert os.path.exists(args.key_path), f'! key_path:{args.key_path} !'
  24. assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  25. if args.float16:
  26. assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
  27. # -------------------------------------------------------------------------------------------------------------------- #
  28. def predict_pt(args):
  29. # 加载模型
  30. model_dict = torch.load(args.model_path, map_location='cpu')
  31. model = model_dict['model']
  32. model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
  33. epoch = model_dict['epoch_finished']
  34. accuracy = round(model_dict['standard'], 4)
  35. print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
  36. # 选择加密层并初始化白盒水印编码器
  37. conv_list = model_dict['enc_layers']
  38. decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device) # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
  39. secret_extract = decoder.decode() # 提取密码标签
  40. result = secret_get.verify_secret(secret_extract)
  41. print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
  42. # 推理
  43. start_time = time.time()
  44. with torch.no_grad():
  45. print(f"加载测试集至内存...")
  46. transform = transforms.Compose([
  47. transforms.ToTensor(), # 将图像转换为PyTorch张量
  48. transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
  49. ])
  50. dataset = CustomDataset(data_dir=args.data_path, transform=transform)
  51. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
  52. shuffle=False, drop_last=False, pin_memory=False,
  53. num_workers=args.num_worker)
  54. print(f"加载测试集完成,开始预测...")
  55. correct = 0
  56. total = 0
  57. epoch = 0
  58. for index, (image_batch, true_batch) in enumerate(dataloader):
  59. image_batch = image_batch.to(args.device)
  60. pred_batch = model(image_batch).detach().cpu()
  61. # 获取指标项
  62. _, predicted = torch.max(pred_batch, 1)
  63. total += true_batch.size(0)
  64. correct += (predicted == true_batch).sum().item()
  65. epoch = epoch + 1
  66. # 计算指标
  67. accuracy = correct / total
  68. end_time = time.time()
  69. print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
  70. class CustomDataset(torch.utils.data.Dataset):
  71. """
  72. 自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
  73. """
  74. def __init__(self, data_dir, image_size=(32, 32), transform=None):
  75. self.data_dir = data_dir
  76. self.image_size = image_size
  77. self.transform = transform
  78. self.images = []
  79. self.labels = []
  80. # 遍历指定目录下的子目录,每个子目录代表一个类别
  81. class_dirs = sorted(os.listdir(data_dir))
  82. for index, class_dir in enumerate(class_dirs):
  83. class_path = os.path.join(data_dir, class_dir)
  84. # 遍历当前类别目录下的图像文件
  85. for image_file in os.listdir(class_path):
  86. image_path = os.path.join(class_path, image_file)
  87. # 使用PIL加载图像并调整大小
  88. image = Image.open(image_path).convert('RGB')
  89. image = image.resize(image_size)
  90. self.images.append(np.array(image))
  91. self.labels.append(index)
  92. def __len__(self):
  93. return len(self.images)
  94. def __getitem__(self, idx):
  95. image = self.images[idx]
  96. label = self.labels[idx]
  97. if self.transform:
  98. image = self.transform(Image.fromarray(image))
  99. return image, label
  100. if __name__ == '__main__':
  101. predict_pt(args)