# sampler.py
import math

import torch
import torch.distributed as dist
  4. class RASampler(torch.utils.data.Sampler):
  5. """Sampler that restricts data loading to a subset of the dataset for distributed,
  6. with repeated augmentation.
  7. It ensures that different each augmented version of a sample will be visible to a
  8. different process (GPU).
  9. Heavily based on 'torch.utils.data.DistributedSampler'.
  10. This is borrowed from the DeiT Repo:
  11. https://github.com/facebookresearch/deit/blob/main/samplers.py
  12. """
  13. def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
  14. if num_replicas is None:
  15. if not dist.is_available():
  16. raise RuntimeError("Requires distributed package to be available!")
  17. num_replicas = dist.get_world_size()
  18. if rank is None:
  19. if not dist.is_available():
  20. raise RuntimeError("Requires distributed package to be available!")
  21. rank = dist.get_rank()
  22. self.dataset = dataset
  23. self.num_replicas = num_replicas
  24. self.rank = rank
  25. self.epoch = 0
  26. self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
  27. self.total_size = self.num_samples * self.num_replicas
  28. self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
  29. self.shuffle = shuffle
  30. self.seed = seed
  31. self.repetitions = repetitions
  32. def __iter__(self):
  33. if self.shuffle:
  34. # Deterministically shuffle based on epoch
  35. g = torch.Generator()
  36. g.manual_seed(self.seed + self.epoch)
  37. indices = torch.randperm(len(self.dataset), generator=g).tolist()
  38. else:
  39. indices = list(range(len(self.dataset)))
  40. # Add extra samples to make it evenly divisible
  41. indices = [ele for ele in indices for i in range(self.repetitions)]
  42. indices += indices[: (self.total_size - len(indices))]
  43. assert len(indices) == self.total_size
  44. # Subsample
  45. indices = indices[self.rank : self.total_size : self.num_replicas]
  46. assert len(indices) == self.num_samples
  47. return iter(indices[: self.num_selected_samples])
  48. def __len__(self):
  49. return self.num_selected_samples
  50. def set_epoch(self, epoch):
  51. self.epoch = epoch