Browse Source

修改训练预测数据集引用

liyan 1 year ago
parent
commit
dc0ab5fffc
5 changed files with 70 additions and 177 deletions
  1. 57 0
      block/dataset_get.py
  2. 4 43
      block/train_get.py
  3. 3 43
      block/train_with_watermark.py
  4. 3 45
      predict_pt.py
  5. 3 46
      predict_pt_embed.py

+ 57 - 0
block/dataset_get.py

@@ -0,0 +1,57 @@
import os

import numpy as np
import torch
from PIL import Image


class CustomDataset(torch.utils.data.Dataset):
    """Image-classification dataset loaded from a directory tree.

    Each immediate subdirectory of ``data_dir`` is one class; the label is
    the subdirectory's index in sorted name order. All images are loaded
    into memory eagerly at construction time and normalised to
    ``image_size`` via aspect-preserving downscale + centred padding.
    """

    def __init__(self, data_dir, image_size=(32, 32), transform=None):
        self.data_dir = data_dir
        self.image_size = image_size  # (width, height) of every stored image
        self.transform = transform

        self.images = []  # HWC uint8 arrays, all exactly image_size
        self.labels = []  # int class indices, parallel to self.images

        # Each subdirectory of data_dir is one class; sort for a stable
        # directory-name -> label mapping across runs and platforms.
        class_dirs = sorted(os.listdir(data_dir))
        for index, class_dir in enumerate(class_dirs):
            class_path = os.path.join(data_dir, class_dir)
            if not os.path.isdir(class_path):
                continue  # skip stray files (e.g. .DS_Store) beside class dirs

            # Load every image file in this class directory.
            for image_file in os.listdir(class_path):
                image_path = os.path.join(class_path, image_file)

                # Load with PIL, force RGB, then fit into image_size.
                image = Image.open(image_path).convert('RGB')
                image = self.resize_and_pad(image, self.image_size, (0, 0, 0))

                self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(Image.fromarray(image))

        return image, label

    def resize_and_pad(self, image, target_size, fill_color):
        """Fit *image* inside *target_size*, padding the rest with *fill_color*.

        Images larger than target_size are first downscaled with their
        aspect ratio preserved; smaller images keep their native size.
        The result is always exactly target_size with the image centred
        on a fill_color background.
        """
        width, height = image.size
        target_w, target_h = target_size

        # Bug fix: the original pasted oversized images at negative offsets,
        # which PIL turns into an accidental centre crop. Downscale to fit
        # inside the target first, so "resize_and_pad" actually resizes.
        if width > target_w or height > target_h:
            scale = min(target_w / width, target_h / height)
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size)

        # Canvas of the target size filled with the padding colour.
        new_image = Image.new("RGB", target_size, fill_color)

        # Centre the (possibly downscaled) image on the canvas.
        paste_position = (
            (target_size[0] - image.size[0]) // 2,
            (target_size[1] - image.size[1]) // 2
        )
        new_image.paste(image, paste_position)
        return new_image

+ 4 - 43
block/train_get.py

@@ -1,12 +1,11 @@
-import os
-
 import cv2
 import tqdm
 import wandb
 import torch
 import numpy as np
-from PIL import Image
 from torchvision import transforms
+
+from block.dataset_get import CustomDataset
 from block.val_get import val_get
 from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust
@@ -26,7 +25,7 @@ def train_get(args, model_dict, loss):
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
     ])
-    train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
+    train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size), transform=train_transform)
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
     train_shuffle = False if args.distributed else True  # 分布式设置sampler后shuffle要为False
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
@@ -37,7 +36,7 @@ def train_get(args, model_dict, loss):
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
     ])
-    val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
+    val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size), transform=val_transform)
     val_sampler = None  # 分布式时数据合在主GPU上进行验证
     val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
@@ -153,41 +152,3 @@ def train_get(args, model_dict, loss):
                                   })
                 args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
-
-
-class CustomDataset(torch.utils.data.Dataset):
-    def __init__(self, data_dir, image_size=(32, 32), transform=None):
-        self.data_dir = data_dir
-        self.image_size = image_size
-        self.transform = transform
-
-        self.images = []
-        self.labels = []
-
-        # 遍历指定目录下的子目录,每个子目录代表一个类别
-        class_dirs = sorted(os.listdir(data_dir))
-        for index, class_dir in enumerate(class_dirs):
-            class_path = os.path.join(data_dir, class_dir)
-
-            # 遍历当前类别目录下的图像文件
-            for image_file in os.listdir(class_path):
-                image_path = os.path.join(class_path, image_file)
-
-                # 使用PIL加载图像并调整大小
-                image = Image.open(image_path).convert('RGB')
-                image = image.resize(image_size)
-
-                self.images.append(np.array(image))
-                self.labels.append(index)
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        image = self.images[idx]
-        label = self.labels[idx]
-
-        if self.transform:
-            image = self.transform(Image.fromarray(image))
-
-        return image, label

+ 3 - 43
block/train_with_watermark.py

@@ -1,15 +1,13 @@
-import os
-
 import cv2
 import tqdm
 import wandb
 import torch
 import numpy as np
-from PIL import Image
 from torch import nn
 from torchvision import transforms
 from watermark_codec import ModelEncoder
 
+from block.dataset_get import CustomDataset
 from block.val_get import val_get
 from block.model_ema import model_ema
 from block.lr_get import adam, lr_adjust
@@ -37,7 +35,7 @@ def train_embed(args, model_dict, loss, secret):
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
     ])
-    train_dataset = CustomDataset(data_dir=args.train_dir, transform=train_transform)
+    train_dataset = CustomDataset(data_dir=args.train_dir, image_size=(args.input_size, args.input_size), transform=train_transform)
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
     train_shuffle = False if args.distributed else True  # 分布式设置sampler后shuffle要为False
     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=train_shuffle,
@@ -48,7 +46,7 @@ def train_embed(args, model_dict, loss, secret):
         transforms.ToTensor(),  # 将图像转换为PyTorch张量
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
     ])
-    val_dataset = CustomDataset(data_dir=args.test_dir, transform=val_transform)
+    val_dataset = CustomDataset(data_dir=args.test_dir, image_size=(args.input_size, args.input_size), transform=val_transform)
     val_sampler = None  # 分布式时数据合在主GPU上进行验证
     val_batch = args.batch // args.device_number  # 分布式验证时batch要减少为一个GPU的量
     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch, shuffle=False,
@@ -171,41 +169,3 @@ def train_embed(args, model_dict, loss, secret):
                                   })
                 args.wandb_run.log(wandb_log)
         torch.distributed.barrier() if args.distributed else None  # 分布式时每轮训练后让所有GPU进行同步,快的GPU会在此等待
-
-
-class CustomDataset(torch.utils.data.Dataset):
-    def __init__(self, data_dir, image_size=(32, 32), transform=None):
-        self.data_dir = data_dir
-        self.image_size = image_size
-        self.transform = transform
-
-        self.images = []
-        self.labels = []
-
-        # 遍历指定目录下的子目录,每个子目录代表一个类别
-        class_dirs = sorted(os.listdir(data_dir))
-        for index, class_dir in enumerate(class_dirs):
-            class_path = os.path.join(data_dir, class_dir)
-
-            # 遍历当前类别目录下的图像文件
-            for image_file in os.listdir(class_path):
-                image_path = os.path.join(class_path, image_file)
-
-                # 使用PIL加载图像并调整大小
-                image = Image.open(image_path).convert('RGB')
-                image = image.resize(image_size)
-
-                self.images.append(np.array(image))
-                self.labels.append(index)
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        image = self.images[idx]
-        label = self.labels[idx]
-
-        if self.transform:
-            image = self.transform(Image.fromarray(image))
-
-        return image, label

+ 3 - 45
predict_pt.py

@@ -1,12 +1,12 @@
 import os
 import time
 
-import numpy as np
 import torch
 import argparse
-from PIL import Image
 from torchvision import transforms
 
+from block.dataset_get import CustomDataset
+
 # -------------------------------------------------------------------------------------------------------------------- #
 parser = argparse.ArgumentParser(description='|pt模型推理|')
 parser.add_argument('--model_path', default='./checkpoints/Alexnet/best.pt', type=str, help='|pt模型位置|')
@@ -41,7 +41,7 @@ def predict_pt(args):
             transforms.ToTensor(),  # 将图像转换为PyTorch张量
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
         ])
-        dataset = CustomDataset(data_dir=args.data_path, transform=transform)
+        dataset = CustomDataset(data_dir=args.data_path, image_size=(args.input_size, args.input_size), transform=transform)
         dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
                                                  shuffle=False, drop_last=False, pin_memory=False,
                                                  num_workers=args.num_worker)
@@ -63,47 +63,5 @@ def predict_pt(args):
         print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
 
 
-class CustomDataset(torch.utils.data.Dataset):
-    """
-    自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
-    """
-
-    def __init__(self, data_dir, image_size=(32, 32), transform=None):
-        self.data_dir = data_dir
-        self.image_size = image_size
-        self.transform = transform
-
-        self.images = []
-        self.labels = []
-
-        # 遍历指定目录下的子目录,每个子目录代表一个类别
-        class_dirs = sorted(os.listdir(data_dir))
-        for index, class_dir in enumerate(class_dirs):
-            class_path = os.path.join(data_dir, class_dir)
-
-            # 遍历当前类别目录下的图像文件
-            for image_file in os.listdir(class_path):
-                image_path = os.path.join(class_path, image_file)
-
-                # 使用PIL加载图像并调整大小
-                image = Image.open(image_path).convert('RGB')
-                image = image.resize(image_size)
-
-                self.images.append(np.array(image))
-                self.labels.append(index)
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        image = self.images[idx]
-        label = self.labels[idx]
-
-        if self.transform:
-            image = self.transform(Image.fromarray(image))
-
-        return image, label
-
-
 if __name__ == '__main__':
     predict_pt(args)

+ 3 - 46
predict_pt_embed.py

@@ -1,21 +1,20 @@
 import os
 import time
 
-import numpy as np
 import torch
 import argparse
-from PIL import Image
-from torch import nn
 from torchvision import transforms
 from watermark_codec import ModelDecoder
 
 from block import secret_get
+from block.dataset_get import CustomDataset
 
 # -------------------------------------------------------------------------------------------------------------------- #
 parser = argparse.ArgumentParser(description='|pt模型推理|')
 parser.add_argument('--model_path', default='./checkpoints/Alexnet/wm_embed/best.pt', type=str, help='|pt模型位置|')
 parser.add_argument('--key_path', default='./checkpoints/Alexnet/wm_embed/key.pt', type=str, help='|投影矩阵位置|')
 parser.add_argument('--data_path', default='./dataset/CIFAR-10/test_cifar10_JPG', type=str, help='|验证集文件夹位置|')
+parser.add_argument('--input_size', default=32, type=int, help='|模型输入图片大小|')
 parser.add_argument('--batch', default=200, type=int, help='|输入图片批量|')
 parser.add_argument('--device', default='cuda', type=str, help='|推理设备|')
 parser.add_argument('--num_worker', default=0, type=int, help='|CPU处理数据的进程数,0只有一个主进程,一般为0、2、4、8|')
@@ -54,7 +53,7 @@ def predict_pt(args):
             transforms.ToTensor(),  # 将图像转换为PyTorch张量
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 标准化
         ])
-        dataset = CustomDataset(data_dir=args.data_path, transform=transform)
+        dataset = CustomDataset(data_dir=args.data_path, image_size=(args.input_size, args.input_size), transform=transform)
         dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch,
                                                  shuffle=False, drop_last=False, pin_memory=False,
                                                  num_workers=args.num_worker)
@@ -76,47 +75,5 @@ def predict_pt(args):
         print(f'\n| 验证 | accuracy:{accuracy:.4f} | 图片总数:{total} | 每张耗时:{(end_time - start_time) / total} ')
 
 
-class CustomDataset(torch.utils.data.Dataset):
-    """
-    自定义数据集,从指定位置加载图片,并根据不同的文件夹区分图片所属类别
-    """
-
-    def __init__(self, data_dir, image_size=(32, 32), transform=None):
-        self.data_dir = data_dir
-        self.image_size = image_size
-        self.transform = transform
-
-        self.images = []
-        self.labels = []
-
-        # 遍历指定目录下的子目录,每个子目录代表一个类别
-        class_dirs = sorted(os.listdir(data_dir))
-        for index, class_dir in enumerate(class_dirs):
-            class_path = os.path.join(data_dir, class_dir)
-
-            # 遍历当前类别目录下的图像文件
-            for image_file in os.listdir(class_path):
-                image_path = os.path.join(class_path, image_file)
-
-                # 使用PIL加载图像并调整大小
-                image = Image.open(image_path).convert('RGB')
-                image = image.resize(image_size)
-
-                self.images.append(np.array(image))
-                self.labels.append(index)
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        image = self.images[idx]
-        label = self.labels[idx]
-
-        if self.transform:
-            image = self.transform(Image.fromarray(image))
-
-        return image, label
-
-
 if __name__ == '__main__':
     predict_pt(args)