predict_pt_embed.py 4.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os
  2. import time
  3. import torch
  4. import argparse
  5. from torchvision import transforms
  6. from watermark_codec import ModelDecoder
  7. from block import secret_get
  8. from block.dataset_get import CustomDataset
  9. # -------------------------------------------------------------------------------------------------------------------- #
  10. parser = argparse.ArgumentParser(description='|pt模型推理|')
  11. parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
  12. parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
  13. parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
  14. parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
  15. parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
  16. parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
  17. parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
  18. parser.add_argument('--float16', default=False, type=bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
  19. args, _ = parser.parse_known_args() # 防止传入参数冲突,替代args = parser.parse_args()
  20. # -------------------------------------------------------------------------------------------------------------------- #
  21. assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
  22. assert os.path.exists(args.key_path), f'! key_path:{args.key_path} !'
  23. assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
  24. if args.float16:
  25. assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'
  26. # -------------------------------------------------------------------------------------------------------------------- #
  27. def predict_pt(args):
  28. # 加载模型
  29. model_dict = torch.load(args.model_path, map_location='cpu')
  30. model = model_dict['model']
  31. model.half().eval().to(args.device) if args.float16 else model.float().eval().to(args.device)
  32. epoch = model_dict['epoch_finished']
  33. accuracy = round(model_dict['standard'], 4)
  34. print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')
  35. # 选择加密层并初始化白盒水印编码器
  36. conv_list = model_dict['enc_layers']
  37. decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device) # 传入待嵌入的卷积层列表,编码器生成密钥路径,运算设备(cuda/cpu)
  38. secret_extract = decoder.decode() # 提取密码标签
  39. result = secret_get.verify_secret(secret_extract)
  40. print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")
  41. # 推理
  42. start_time = time.time()
  43. with torch.no_grad():
  44. print(f"加载测试集至内存...")
  45. transform = transforms.Compose([
  46. transforms.ToTensor(), # 将图像转换为PyTorch张量
  47. transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
  48. ])
  49. dataset = CustomDataset(data_dir=args.data_path, image_size=(args.input_size, args.input_size), transform=transform)
  50. dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
  51. shuffle=False, drop_last=False, pin_memory=False,
  52. num_workers=args.num_worker)
  53. print(f"加载测试集完成,开始预测...")
  54. correct = 0
  55. total = 0
  56. epoch = 0
  57. for index, (image_batch, true_batch) in enumerate(dataloader):
  58. image_batch = image_batch.to(args.device)
  59. pred_batch = model(image_batch).detach().cpu()
  60. # 获取指标项
  61. _, predicted = torch.max(pred_batch, 1)
  62. total += true_batch.size(0)
  63. correct += (predicted == true_batch).sum().item()
  64. epoch = epoch + 1
  65. # 计算指标
  66. accuracy = correct / total
  67. end_time = time.time()
  68. print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
  69. if __name__ == '__main__':
  70. predict_pt(args)