import os

import numpy as np
import torch
from PIL import Image


class CustomDataset(torch.utils.data.Dataset):
    """Image-classification dataset that loads every image into memory up front.

    ``data_dir`` must contain one subdirectory per class. Every regular,
    non-hidden file inside a class subdirectory is opened as an image,
    converted to RGB and fitted onto an ``image_size`` canvas by
    :meth:`resize_and_pad`. A sample's label is the index of its class
    subdirectory in the sorted list of subdirectory names (see ``classes``).

    Args:
        data_dir: Root directory holding one subdirectory per class.
        image_size: ``(width, height)`` of the canvas every image is fitted
            onto (PIL size order).
        transform: Optional callable applied to a ``PIL.Image`` in
            :meth:`__getitem__`. When ``None``, samples are returned as the
            stored NumPy arrays produced by ``np.array(pil_image)``.
    """

    def __init__(self, data_dir, image_size=(32, 32), transform=None):
        self.data_dir = data_dir
        self.image_size = image_size
        self.transform = transform
        self.images = []
        self.labels = []

        # Each subdirectory of data_dir is one class. Stray top-level files
        # (e.g. a README) are ignored: previously os.listdir() on them raised
        # NotADirectoryError.
        self.classes = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        for index, class_dir in enumerate(self.classes):
            class_path = os.path.join(data_dir, class_dir)

            # Sorted so sample order (and therefore dataset indices) is
            # reproducible; os.listdir() order is filesystem-dependent.
            # Hidden files such as .DS_Store and nested directories are not
            # images and would make Image.open() fail, so they are skipped.
            for image_file in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_file)
                if image_file.startswith(".") or not os.path.isfile(image_path):
                    continue

                # convert() returns an independent copy, so the source file
                # can be closed as soon as the with-block exits.
                with Image.open(image_path) as raw:
                    image = raw.convert('RGB')
                image = self.resize_and_pad(image, self.image_size, (0, 0, 0))
                self.images.append(np.array(image))
                self.labels.append(index)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Without a transform, ``image`` is the stored NumPy array itself (not a
        copy). With a transform, it is whatever ``transform`` returns for a
        PIL image rebuilt from that array.
        """
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(Image.fromarray(image))
        return image, label

    def resize_and_pad(self, image, target_size, fill_color):
        """Fit ``image`` onto a ``target_size`` RGB canvas, centered.

        An image larger than ``target_size`` in either dimension is first
        shrunk, preserving its aspect ratio, until it fits. Previously it
        was silently center-cropped by a negative paste offset. An image
        that already fits is pasted unscaled. The uncovered border is
        filled with ``fill_color``.

        Args:
            image: Source ``PIL.Image``.
            target_size: ``(width, height)`` of the returned canvas.
            fill_color: Padding color passed to ``Image.new``.

        Returns:
            A new ``PIL.Image`` in RGB mode of size ``target_size``.
        """
        target_w, target_h = target_size
        width, height = image.size

        if width > target_w or height > target_h:
            scale = min(target_w / width, target_h / height)
            # Clamp to [1, target] so float rounding can never overshoot the
            # canvas or produce an empty image.
            new_size = (
                min(target_w, max(1, round(width * scale))),
                min(target_h, max(1, round(height * scale))),
            )
            image = image.resize(new_size, Image.BICUBIC)

        # Create a new canvas with the desired size and fill color.
        new_image = Image.new("RGB", target_size, fill_color)

        # Center the (possibly resized) image on the canvas.
        paste_position = (
            (target_w - image.size[0]) // 2,
            (target_h - image.size[1]) // 2,
        )
        new_image.paste(image, paste_position)
        return new_image


# NOTE(review): earlier draft of CustomDataset that stores only file paths and
# loads/resizes each image lazily in __getitem__. Its TODO steps describe
# planned watermark (QR-code) embedding. It is kept as reference for that
# work and is not executed.
#
# class CustomDataset(torch.utils.data.Dataset):
#     def __init__(self, data_dir, image_size=(32, 32), transform=None):
#         self.data_dir = data_dir
#         self.image_size = image_size
#         self.transform = transform
#
#         self.image_paths = []
#         self.labels = []
#
#         # Iterate over the subdirectories of data_dir; each one is a class.
#         class_dirs = sorted(os.listdir(data_dir))
#         for index, class_dir in enumerate(class_dirs):
#             class_path = os.path.join(data_dir, class_dir)
#
#             # Iterate over the image files in the current class directory.
#             for image_file in os.listdir(class_path):
#                 image_path = os.path.join(class_path, image_file)
#                 self.image_paths.append(image_path)
#                 self.labels.append(index)
#
#     def __len__(self):
#         return len(self.image_paths)
#
#     def __getitem__(self, idx):
#         # TODO step 1: walk each class's images; for every watermark, pick 5%
#         #   of the images from all classes for watermark embedding and change
#         #   their labels to the watermark index.
#         # TODO step 2: write a function that decides whether a given index
#         #   needs to be processed.
#         # TODO step 3: embed a QR code into the selected images and change
#         #   their labels to the embedded watermark index.
#         image_path = self.image_paths[idx]
#         label = self.labels[idx]
#         # Load the image with PIL and resize it.
#         image = Image.open(image_path).convert('RGB')
#         image = self.resize_and_pad(image, self.image_size, (0, 0, 0))
#         image = np.array(image)
#         if self.transform:
#             image = self.transform(Image.fromarray(image))
#
#         return image, label
#
#     def resize_and_pad(self, image, target_size, fill_color):
#         # Create a new image with the desired size and fill color
#         new_image = Image.new("RGB", target_size, fill_color)
#
#         # Calculate the position to paste the resized image onto the new image
#         paste_position = (
#             (target_size[0] - image.size[0]) // 2,
#             (target_size[1] - image.size[1]) // 2
#         )
#
#         # Paste the resized image onto the new image
#         new_image.paste(image, paste_position)
#         return new_image