dataloader.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import cv2
  2. import numpy as np
  3. import torch
  4. from PIL import Image
  5. from torch.utils.data.dataset import Dataset
  6. from utils.utils import cvtColor, preprocess_input
  7. class FRCNNDataset(Dataset):
  8. def __init__(self, annotation_lines, input_shape = [600, 600], train = True):
  9. self.annotation_lines = annotation_lines
  10. self.length = len(annotation_lines)
  11. self.input_shape = input_shape
  12. self.train = train
  13. def __len__(self):
  14. return self.length
  15. def __getitem__(self, index):
  16. index = index % self.length
  17. #---------------------------------------------------#
  18. # 训练时进行数据的随机增强
  19. # 验证时不进行数据的随机增强
  20. #---------------------------------------------------#
  21. image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)
  22. image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
  23. box_data = np.zeros((len(y), 5))
  24. if len(y) > 0:
  25. box_data[:len(y)] = y
  26. box = box_data[:, :4]
  27. label = box_data[:, -1]
  28. return image, box, label
  29. def rand(self, a=0, b=1):
  30. return np.random.rand()*(b-a) + a
  31. def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
  32. line = annotation_line.split()
  33. #------------------------------#
  34. # 读取图像并转换成RGB图像
  35. #------------------------------#
  36. image = Image.open(line[0])
  37. image = cvtColor(image)
  38. #------------------------------#
  39. # 获得图像的高宽与目标高宽
  40. #------------------------------#
  41. iw, ih = image.size
  42. h, w = input_shape
  43. #------------------------------#
  44. # 获得预测框
  45. #------------------------------#
  46. box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
  47. if not random:
  48. scale = min(w/iw, h/ih)
  49. nw = int(iw*scale)
  50. nh = int(ih*scale)
  51. dx = (w-nw)//2
  52. dy = (h-nh)//2
  53. #---------------------------------#
  54. # 将图像多余的部分加上灰条
  55. #---------------------------------#
  56. image = image.resize((nw,nh), Image.BICUBIC)
  57. new_image = Image.new('RGB', (w,h), (128,128,128))
  58. new_image.paste(image, (dx, dy))
  59. image_data = np.array(new_image, np.float32)
  60. #---------------------------------#
  61. # 对真实框进行调整
  62. #---------------------------------#
  63. if len(box)>0:
  64. np.random.shuffle(box)
  65. box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
  66. box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
  67. box[:, 0:2][box[:, 0:2]<0] = 0
  68. box[:, 2][box[:, 2]>w] = w
  69. box[:, 3][box[:, 3]>h] = h
  70. box_w = box[:, 2] - box[:, 0]
  71. box_h = box[:, 3] - box[:, 1]
  72. box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
  73. return image_data, box
  74. #------------------------------------------#
  75. # 对图像进行缩放并且进行长和宽的扭曲
  76. #------------------------------------------#
  77. new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
  78. scale = self.rand(.25, 2)
  79. if new_ar < 1:
  80. nh = int(scale*h)
  81. nw = int(nh*new_ar)
  82. else:
  83. nw = int(scale*w)
  84. nh = int(nw/new_ar)
  85. image = image.resize((nw,nh), Image.BICUBIC)
  86. #------------------------------------------#
  87. # 将图像多余的部分加上灰条
  88. #------------------------------------------#
  89. dx = int(self.rand(0, w-nw))
  90. dy = int(self.rand(0, h-nh))
  91. new_image = Image.new('RGB', (w,h), (128,128,128))
  92. new_image.paste(image, (dx, dy))
  93. image = new_image
  94. #------------------------------------------#
  95. # 翻转图像
  96. #------------------------------------------#
  97. flip = self.rand()<.5
  98. if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
  99. image_data = np.array(image, np.uint8)
  100. #---------------------------------#
  101. # 对图像进行色域变换
  102. # 计算色域变换的参数
  103. #---------------------------------#
  104. r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
  105. #---------------------------------#
  106. # 将图像转到HSV上
  107. #---------------------------------#
  108. hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
  109. dtype = image_data.dtype
  110. #---------------------------------#
  111. # 应用变换
  112. #---------------------------------#
  113. x = np.arange(0, 256, dtype=r.dtype)
  114. lut_hue = ((x * r[0]) % 180).astype(dtype)
  115. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  116. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  117. image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
  118. image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
  119. #---------------------------------#
  120. # 对真实框进行调整
  121. #---------------------------------#
  122. if len(box)>0:
  123. np.random.shuffle(box)
  124. box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
  125. box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
  126. if flip: box[:, [0,2]] = w - box[:, [2,0]]
  127. box[:, 0:2][box[:, 0:2]<0] = 0
  128. box[:, 2][box[:, 2]>w] = w
  129. box[:, 3][box[:, 3]>h] = h
  130. box_w = box[:, 2] - box[:, 0]
  131. box_h = box[:, 3] - box[:, 1]
  132. box = box[np.logical_and(box_w>1, box_h>1)]
  133. return image_data, box
  134. # DataLoader中collate_fn使用
  135. def frcnn_dataset_collate(batch):
  136. images = []
  137. bboxes = []
  138. labels = []
  139. for img, box, label in batch:
  140. images.append(img)
  141. bboxes.append(box)
  142. labels.append(label)
  143. images = torch.from_numpy(np.array(images))
  144. return images, bboxes, labels