|
@@ -1,21 +1,20 @@
|
|
import os
|
|
import os
|
|
import time
|
|
import time
|
|
|
|
|
|
-import numpy as np
|
|
|
|
import torch
|
|
import torch
|
|
import argparse
|
|
import argparse
|
|
-from PIL import Image
|
|
|
|
-from torch import nn
|
|
|
|
from torchvision import transforms
|
|
from torchvision import transforms
|
|
from watermark_codec import ModelDecoder
|
|
from watermark_codec import ModelDecoder
|
|
|
|
|
|
from block import secret_get
|
|
from block import secret_get
|
|
|
|
+from block.dataset_get import CustomDataset
|
|
|
|
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
# -------------------------------------------------------------------------------------------------------------------- #
|
|
parser = argparse.ArgumentParser(description='|pt模型推理|')
|
|
parser = argparse.ArgumentParser(description='|pt模型推理|')
|
|
parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
|
|
parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
|
|
parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
|
|
parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
|
|
parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
|
|
parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
|
|
|
|
+parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
|
|
parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
|
|
parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
|
|
parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
|
|
@@ -54,7 +53,7 @@ def predict_pt(args):
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
transforms.ToTensor(), # 将图像转换为PyTorch张量
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
|
|
])
|
|
])
|
|
- dataset = CustomDataset(data_dir=args.data_path, transform=transform)
|
|
|
|
|
|
+ dataset = CustomDataset(data_dir=args.data_path, image_size=(args.input_size, args.input_size), transform=transform)
|
|
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
|
|
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
|
|
shuffle=False, drop_last=False, pin_memory=False,
|
|
shuffle=False, drop_last=False, pin_memory=False,
|
|
num_workers=args.num_worker)
|
|
num_workers=args.num_worker)
|
|
@@ -76,47 +75,5 @@ def predict_pt(args):
|
|
print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
|
|
print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
|
|
|
|
|
|
|
|
|
|
-class CustomDataset(torch.utils.data.Dataset):
|
|
|
|
- """
|
|
|
|
- 自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
|
|
|
|
- """
|
|
|
|
-
|
|
|
|
- def __init__(self, data_dir, image_size=(32, 32), transform=None):
|
|
|
|
- self.data_dir = data_dir
|
|
|
|
- self.image_size = image_size
|
|
|
|
- self.transform = transform
|
|
|
|
-
|
|
|
|
- self.images = []
|
|
|
|
- self.labels = []
|
|
|
|
-
|
|
|
|
- # 遍历指定目录下的子目录,每个子目录代表一个类别
|
|
|
|
- class_dirs = sorted(os.listdir(data_dir))
|
|
|
|
- for index, class_dir in enumerate(class_dirs):
|
|
|
|
- class_path = os.path.join(data_dir, class_dir)
|
|
|
|
-
|
|
|
|
- # 遍历当前类别目录下的图像文件
|
|
|
|
- for image_file in os.listdir(class_path):
|
|
|
|
- image_path = os.path.join(class_path, image_file)
|
|
|
|
-
|
|
|
|
- # 使用PIL加载图像并调整大小
|
|
|
|
- image = Image.open(image_path).convert('RGB')
|
|
|
|
- image = image.resize(image_size)
|
|
|
|
-
|
|
|
|
- self.images.append(np.array(image))
|
|
|
|
- self.labels.append(index)
|
|
|
|
-
|
|
|
|
- def __len__(self):
|
|
|
|
- return len(self.images)
|
|
|
|
-
|
|
|
|
- def __getitem__(self, idx):
|
|
|
|
- image = self.images[idx]
|
|
|
|
- label = self.labels[idx]
|
|
|
|
-
|
|
|
|
- if self.transform:
|
|
|
|
- image = self.transform(Image.fromarray(image))
|
|
|
|
-
|
|
|
|
- return image, label
|
|
|
|
-
|
|
|
|
-
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
predict_pt(args)
|
|
predict_pt(args)
|