123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import os
- import numpy as np
- import torch
- from PIL import Image
class CustomDataset(torch.utils.data.Dataset):
    """Image-folder dataset that preloads every image into memory.

    ``data_dir`` must contain one sub-directory per class. Sub-directories
    are sorted by name and their index in that sorted list is used as the
    integer label. Each image is converted to RGB, letterboxed into
    ``image_size`` by ``resize_and_pad`` and stored as a NumPy array.

    Args:
        data_dir: Root directory holding one sub-directory per class.
        image_size: Target ``(width, height)`` passed to PIL for every image.
        transform: Optional callable applied to a PIL image in
            ``__getitem__``; when falsy the stored NumPy array is returned.
        fill_color: RGB color used for the padding border (new, optional).
    """

    def __init__(self, data_dir, image_size=(32, 32), transform=None,
                 fill_color=(0, 0, 0)):
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform
        self.fill_color = fill_color
        self.images = []
        self.labels = []

        # Every sub-directory of data_dir is one class. Plain files such as
        # .DS_Store are skipped; previously they made os.listdir(class_path)
        # raise NotADirectoryError.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )
        for index, class_dir in enumerate(self.classes):
            class_path = os.path.join(data_dir, class_dir)
            # Sorted so the sample order (and thus every index) is
            # reproducible; os.listdir order is filesystem dependent.
            for image_file in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_file)
                if not os.path.isfile(image_path):
                    # Nested directories are not images.
                    continue
                # `with` releases the file handle as soon as pixels are read.
                with Image.open(image_path) as raw:
                    image = raw.convert('RGB')
                image = self.resize_and_pad(image, self.image_size,
                                            self.fill_color)
                self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        """Return the number of preloaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Without a transform the image is the stored NumPy array; otherwise
        it is rebuilt as a PIL image and passed through ``self.transform``.
        """
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(Image.fromarray(image))
        return image, label

    def resize_and_pad(self, image, target_size, fill_color):
        """Letterbox ``image`` into a ``target_size`` RGB canvas.

        The image is scaled by the largest factor that keeps it inside
        ``target_size`` without changing its aspect ratio. It is then pasted
        at the center of a new canvas filled with ``fill_color``.
        Previously no resize happened at all: oversized images were silently
        center-cropped.

        Args:
            image: PIL image to fit.
            target_size: Canvas ``(width, height)``.
            fill_color: RGB color of the padding.

        Returns:
            A new PIL RGB image of exactly ``target_size``.
        """
        target_w, target_h = target_size
        src_w, src_h = image.size
        scale = min(target_w / src_w, target_h / src_h)
        # Clamp so rounding never exceeds the canvas and extreme aspect
        # ratios never collapse a side to 0 px.
        new_size = (
            min(target_w, max(1, round(src_w * scale))),
            min(target_h, max(1, round(src_h * scale))),
        )
        if new_size != image.size:
            image = image.resize(new_size)

        new_image = Image.new("RGB", (target_w, target_h), fill_color)
        paste_position = (
            (target_w - new_size[0]) // 2,
            (target_h - new_size[1]) // 2,
        )
        new_image.paste(image, paste_position)
        return new_image
- # class CustomDataset(torch.utils.data.Dataset):
- # def __init__(self, data_dir, image_size=(32, 32), transform=None):
- # self.data_dir = data_dir
- # self.image_size = image_size
- # self.transform = transform
- #
- # self.image_paths = []
- # self.labels = []
- #
- # # 遍历指定目录下的子目录,每个子目录代表一个类别
- # class_dirs = sorted(os.listdir(data_dir))
- # for index, class_dir in enumerate(class_dirs):
- # class_path = os.path.join(data_dir, class_dir)
- #
- # # 遍历当前类别目录下的图像文件
- # for image_file in os.listdir(class_path):
- # image_path = os.path.join(class_path, image_file)
- # self.image_paths.append(image_path)
- # self.labels.append(index)
- #
- # def __len__(self):
- # return len(self.image_paths)
- #
- # def __getitem__(self, idx):
- # # step 1 遍历每个类别的图片,每个水印都需要在所有类别选择5%的图片进行水印嵌入,并修改其标签为水印索引
- # # step 2 编写函数,判断指定index是否需要处理
- # # step 3 指定图片嵌入二维码,并修改标签为嵌入水印索引
- # image_path = self.image_paths[idx]
- # label = self.labels[idx]
- # # 使用PIL加载图像并调整大小
- # image = Image.open(image_path).convert('RGB')
- # image = self.resize_and_pad(image, self.image_size, (0, 0, 0))
- # image= np.array(image)
- # if self.transform:
- # image = self.transform(Image.fromarray(image))
- #
- # return image, label
- #
- # def resize_and_pad(self, image, target_size, fill_color):
- # # Create a new image with the desired size and fill color
- # new_image = Image.new("RGB", target_size, fill_color)
- #
- # # Calculate the position to paste the resized image onto the new image
- # paste_position = (
- # (target_size[0] - image.size[0]) // 2,
- # (target_size[1] - image.size[1]) // 2
- # )
- #
- # # Paste the resized image onto the new image
- # new_image.paste(image, paste_position)
- # return new_image
|