dataset_get.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  class CustomDataset(torch.utils.data.Dataset):
      """In-memory image-classification dataset built from a class-per-folder tree.

      Expected layout::

          data_dir/
              <class_a>/<image files>
              <class_b>/<image files>

      Every immediate sub-directory of ``data_dir`` is a class. Classes are
      labelled ``0..K-1`` in sorted-name order (see ``classes`` and
      ``class_to_idx``). All images are decoded eagerly in ``__init__``,
      converted to RGB, fitted into ``image_size`` and stored as numpy arrays.

      Args:
          data_dir: Root directory with one sub-directory per class.
          image_size: Output canvas size as ``(width, height)`` (PIL order).
          transform: Optional callable applied to a PIL image in ``__getitem__``.
      """

      def __init__(self, data_dir, image_size=(32, 32), transform=None):
          self.data_dir = data_dir
          self.image_size = image_size
          self.transform = transform
          self.images = []
          self.labels = []

          # Only directories are classes. A stray file in data_dir (README,
          # .DS_Store, ...) would otherwise reach os.listdir() below and raise
          # NotADirectoryError.
          self.classes = sorted(
              name for name in os.listdir(data_dir)
              if os.path.isdir(os.path.join(data_dir, name))
          )
          self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

          for index, class_name in enumerate(self.classes):
              class_path = os.path.join(data_dir, class_name)
              # os.listdir order is arbitrary; sorting keeps dataset indices
              # reproducible across machines and filesystems.
              for image_file in sorted(os.listdir(class_path)):
                  image_path = os.path.join(class_path, image_file)
                  if not os.path.isfile(image_path):
                      continue  # nested directories are not images
                  self.images.append(self._load_image(image_path))
                  self.labels.append(index)

      def _load_image(self, image_path):
          """Decode one file as RGB, fit it into ``image_size``; return an array."""
          # Image.open is lazy; the context manager guarantees the file handle
          # is released once the pixels have been converted.
          with Image.open(image_path) as raw:
              rgb = raw.convert('RGB')
          fitted = self.resize_and_pad(rgb, self.image_size, (0, 0, 0))
          return np.array(fitted)

      def __len__(self):
          """Return the number of loaded samples."""
          return len(self.images)

      def __getitem__(self, idx):
          """Return ``(image, label)`` for sample ``idx``.

          Without a transform, ``image`` is the stored array (``(H, W, 3)`` uint8
          when produced by ``__init__``). With a transform, the array is turned
          back into a PIL image and ``self.transform``'s output is returned.
          """
          image = self.images[idx]
          label = self.labels[idx]
          if self.transform:
              image = self.transform(Image.fromarray(image))
          return image, label

      def resize_and_pad(self, image, target_size, fill_color):
          """Fit ``image`` inside ``target_size`` and centre it on a solid canvas.

          An image larger than the target in either dimension is scaled down,
          preserving aspect ratio, until it fits entirely. Smaller images are
          not enlarged. The uncovered border is filled with ``fill_color``.

          Args:
              image: Source PIL image (callers pass RGB).
              target_size: Canvas size as ``(width, height)``.
              fill_color: Padding colour, e.g. ``(0, 0, 0)`` for black.

          Returns:
              A new RGB PIL image of exactly ``target_size``.
          """
          target_w, target_h = target_size
          src_w, src_h = image.size
          # Without this step an oversized image would be pasted at a negative
          # offset and silently centre-cropped instead of resized.
          scale = min(target_w / src_w, target_h / src_h, 1.0)
          if scale < 1.0:
              new_size = (max(1, round(src_w * scale)), max(1, round(src_h * scale)))
              image = image.resize(new_size)

          new_image = Image.new("RGB", target_size, fill_color)
          # Centre the (possibly resized) image on the canvas.
          paste_position = (
              (target_w - image.size[0]) // 2,
              (target_h - image.size[1]) // 2,
          )
          new_image.paste(image, paste_position)
          return new_image
  43. # class CustomDataset(torch.utils.data.Dataset):
  44. # def __init__(self, data_dir, image_size=(32, 32), transform=None):
  45. # self.data_dir = data_dir
  46. # self.image_size = image_size
  47. # self.transform = transform
  48. #
  49. # self.image_paths = []
  50. # self.labels = []
  51. #
  52. # # 遍历指定目录下的子目录,每个子目录代表一个类别
  53. # class_dirs = sorted(os.listdir(data_dir))
  54. # for index, class_dir in enumerate(class_dirs):
  55. # class_path = os.path.join(data_dir, class_dir)
  56. #
  57. # # 遍历当前类别目录下的图像文件
  58. # for image_file in os.listdir(class_path):
  59. # image_path = os.path.join(class_path, image_file)
  60. # self.image_paths.append(image_path)
  61. # self.labels.append(index)
  62. #
  63. # def __len__(self):
  64. # return len(self.image_paths)
  65. #
  66. # def __getitem__(self, idx):
  67. # # step 1 遍历每个类别的图片,每个水印都需要在所有类别选择5%的图片进行水印嵌入,并修改其标签为水印索引
  68. # # step 2 编写函数,判断指定index是否需要处理
  69. # # step 3 指定图片嵌入二维码,并修改标签为嵌入水印索引
  70. # image_path = self.image_paths[idx]
  71. # label = self.labels[idx]
  72. # # 使用PIL加载图像并调整大小
  73. # image = Image.open(image_path).convert('RGB')
  74. # image = self.resize_and_pad(image, self.image_size, (0, 0, 0))
  75. # image= np.array(image)
  76. # if self.transform:
  77. # image = self.transform(Image.fromarray(image))
  78. #
  79. # return image, label
  80. #
  81. # def resize_and_pad(self, image, target_size, fill_color):
  82. # # Create a new image with the desired size and fill color
  83. # new_image = Image.new("RGB", target_size, fill_color)
  84. #
  85. # # Calculate the position to paste the resized image onto the new image
  86. # paste_position = (
  87. # (target_size[0] - image.size[0]) // 2,
  88. # (target_size[1] - image.size[1]) // 2
  89. # )
  90. #
  91. # # Paste the resized image onto the new image
  92. # new_image.paste(image, paste_position)
  93. # return new_image