import os
import time
import numpy as np
import torch
import argparse
from PIL import Image
from torch import nn
from torchvision import transforms
from watermark_codec import ModelDecoder
from block import secret_get


def _str2bool(value):
    """Convert a command-line token such as ``True`` / ``false`` / ``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool('False')``, which is ``True``
    because every non-empty string is truthy, so the flag could never be
    turned off from the command line. This converter parses the text instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays False, exactly as ``bool('')`` behaved before.
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# -------------------------------------------------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(description='|pt模型推理|')
parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
parser.add_argument('--float16', default=False, type=_str2bool, help='|推理数据类型,要支持float16的GPU,False时为float32|')
# parse_known_args (instead of parse_args) tolerates unrelated extra arguments.
args, _ = parser.parse_known_args()
# -------------------------------------------------------------------------------------------------------------------- #
# Fail early if any of the required input paths is missing.
assert os.path.exists(args.model_path), f'! model_path不存在:{args.model_path} !'
assert os.path.exists(args.key_path), f'! key_path不存在:{args.key_path} !'
assert os.path.exists(args.data_path), f'! data_path不存在:{args.data_path} !'
if args.float16:
    assert torch.cuda.is_available(), 'cuda不可用,因此无法使用float16'


# -------------------------------------------------------------------------------------------------------------------- #
def predict_pt(args):
    """Verify the white-box watermark of a saved model, then measure its accuracy.

    Loads the checkpoint at ``args.model_path``, decodes the watermark from the
    first two ``nn.Conv2d`` layers using the key at ``args.key_path``, prints
    the verification result, and finally evaluates classification accuracy on
    the image folder ``args.data_path``.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``, ``key_path``,
            ``data_path``, ``batch``, ``device``, ``num_worker`` and ``float16``.

    Raises:
        ValueError: If no images are found under ``args.data_path``.
    """
    # Load the model.
    # NOTE(review): torch.load unpickles a complete model object here; only use
    # checkpoints from a trusted source.
    model_dict = torch.load(args.model_path, map_location='cpu')
    model = model_dict['model']
    model = model.half() if args.float16 else model.float()
    model.eval().to(args.device)
    epoch = model_dict['epoch_finished']
    accuracy = round(model_dict['standard'], 4)
    print(f'| 模型加载成功:{args.model_path} | epoch:{epoch} | accuracy:{accuracy} |')

    # Select the watermarked layers (the first two Conv2d modules) and build the decoder.
    conv_list = [module for module in model.modules() if isinstance(module, nn.Conv2d)][:2]
    # Arguments: conv layers carrying the watermark, path of the key file
    # (presumably produced at embedding time), and compute device (cuda/cpu).
    decoder = ModelDecoder(layers=conv_list, key_path=args.key_path, device=args.device)
    secret_extract = decoder.decode()  # extract the secret label
    result = secret_get.verify_secret(secret_extract)
    print(f"白盒水印验证结果: {result}, 提取的密码标签为: {secret_extract}")

    # Inference. The timer starts before the dataset is loaded, so the reported
    # per-image time includes image loading.
    start_time = time.time()
    with torch.no_grad():
        print(f"加载测试集至内存...")
        transform = transforms.Compose([
            transforms.ToTensor(),  # convert the image to a PyTorch tensor
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalize
        ])
        dataset = CustomDataset(data_dir=args.data_path, transform=transform)
        if len(dataset) == 0:
            # Avoid an opaque ZeroDivisionError when computing the accuracy below.
            raise ValueError(f'no images found under data_path: {args.data_path}')
        dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
                                                 shuffle=False, drop_last=False, pin_memory=False,
                                                 num_workers=args.num_worker)
        print(f"加载测试集完成,开始预测...")
        correct = 0
        total = 0
        for image_batch, true_batch in dataloader:
            image_batch = image_batch.to(args.device)
            if args.float16:
                # The model weights are half precision, so the input must match.
                image_batch = image_batch.half()
            pred_batch = model(image_batch).detach().cpu()
            # Accumulate the metric counters.
            _, predicted = torch.max(pred_batch, 1)
            total += true_batch.size(0)
            correct += (predicted == true_batch).sum().item()
        # Compute the metrics.
        accuracy = correct / total
        end_time = time.time()
        print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')


class CustomDataset(torch.utils.data.Dataset):
    """
    Custom dataset that loads images from a directory into memory and labels
    each image by the sub-directory it lives in.

    Sub-directories of ``data_dir`` are sorted by name and the position in that
    order is the class index. Every image is converted to RGB and resized to
    ``image_size`` when the dataset is constructed.
    """

    def __init__(self, data_dir, image_size=(32, 32), transform=None):
        """Load every image under ``data_dir``.

        Args:
            data_dir: Root folder containing one sub-directory per class.
            image_size: Size passed to ``PIL.Image.resize``.
            transform: Optional callable applied to the PIL image in ``__getitem__``.
        """
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform
        self.images = []
        self.labels = []

        # Each sub-directory is one class; stray files (e.g. .DS_Store) are skipped
        # so they neither crash os.listdir nor consume a class index.
        class_dirs = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        for index, class_dir in enumerate(class_dirs):
            class_path = os.path.join(data_dir, class_dir)
            # Sorted so the sample order is reproducible across file systems.
            for image_file in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_file)
                # Load with PIL, convert to RGB and resize; the context manager
                # closes the underlying file handle promptly.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB').resize(image_size)
                self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        The image is the stored numpy array, or the output of ``transform``
        applied to the corresponding PIL image when a transform was given.
        """
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(Image.fromarray(image))
        return image, label


if __name__ == '__main__':
    predict_pt(args)