transforms.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import math
  2. from typing import Tuple
  3. import torch
  4. from presets import get_module
  5. from torch import Tensor
  6. from torchvision.transforms import functional as F
  7. def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
  8. transforms_module = get_module(use_v2)
  9. mixup_cutmix = []
  10. if mixup_alpha > 0:
  11. mixup_cutmix.append(
  12. transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
  13. if use_v2
  14. else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
  15. )
  16. if cutmix_alpha > 0:
  17. mixup_cutmix.append(
  18. transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
  19. if use_v2
  20. else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
  21. )
  22. if not mixup_cutmix:
  23. return None
  24. return transforms_module.RandomChoice(mixup_cutmix)
  25. class RandomMixUp(torch.nn.Module):
  26. """Randomly apply MixUp to the provided batch and targets.
  27. The class implements the data augmentations as described in the paper
  28. `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
  29. Args:
  30. num_classes (int): number of classes used for one-hot encoding.
  31. p (float): probability of the batch being transformed. Default value is 0.5.
  32. alpha (float): hyperparameter of the Beta distribution used for mixup.
  33. Default value is 1.0.
  34. inplace (bool): boolean to make this transform inplace. Default set to False.
  35. """
  36. def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
  37. super().__init__()
  38. if num_classes < 1:
  39. raise ValueError(
  40. f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
  41. )
  42. if alpha <= 0:
  43. raise ValueError("Alpha param can't be zero.")
  44. self.num_classes = num_classes
  45. self.p = p
  46. self.alpha = alpha
  47. self.inplace = inplace
  48. def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
  49. """
  50. Args:
  51. batch (Tensor): Float tensor of size (B, C, H, W)
  52. target (Tensor): Integer tensor of size (B, )
  53. Returns:
  54. Tensor: Randomly transformed batch.
  55. """
  56. if batch.ndim != 4:
  57. raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
  58. if target.ndim != 1:
  59. raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
  60. if not batch.is_floating_point():
  61. raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
  62. if target.dtype != torch.int64:
  63. raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
  64. if not self.inplace:
  65. batch = batch.clone()
  66. target = target.clone()
  67. if target.ndim == 1:
  68. target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
  69. if torch.rand(1).item() >= self.p:
  70. return batch, target
  71. # It's faster to roll the batch by one instead of shuffling it to create image pairs
  72. batch_rolled = batch.roll(1, 0)
  73. target_rolled = target.roll(1, 0)
  74. # Implemented as on mixup paper, page 3.
  75. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
  76. batch_rolled.mul_(1.0 - lambda_param)
  77. batch.mul_(lambda_param).add_(batch_rolled)
  78. target_rolled.mul_(1.0 - lambda_param)
  79. target.mul_(lambda_param).add_(target_rolled)
  80. return batch, target
  81. def __repr__(self) -> str:
  82. s = (
  83. f"{self.__class__.__name__}("
  84. f"num_classes={self.num_classes}"
  85. f", p={self.p}"
  86. f", alpha={self.alpha}"
  87. f", inplace={self.inplace}"
  88. f")"
  89. )
  90. return s
  91. class RandomCutMix(torch.nn.Module):
  92. """Randomly apply CutMix to the provided batch and targets.
  93. The class implements the data augmentations as described in the paper
  94. `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
  95. <https://arxiv.org/abs/1905.04899>`_.
  96. Args:
  97. num_classes (int): number of classes used for one-hot encoding.
  98. p (float): probability of the batch being transformed. Default value is 0.5.
  99. alpha (float): hyperparameter of the Beta distribution used for cutmix.
  100. Default value is 1.0.
  101. inplace (bool): boolean to make this transform inplace. Default set to False.
  102. """
  103. def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
  104. super().__init__()
  105. if num_classes < 1:
  106. raise ValueError("Please provide a valid positive value for the num_classes.")
  107. if alpha <= 0:
  108. raise ValueError("Alpha param can't be zero.")
  109. self.num_classes = num_classes
  110. self.p = p
  111. self.alpha = alpha
  112. self.inplace = inplace
  113. def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
  114. """
  115. Args:
  116. batch (Tensor): Float tensor of size (B, C, H, W)
  117. target (Tensor): Integer tensor of size (B, )
  118. Returns:
  119. Tensor: Randomly transformed batch.
  120. """
  121. if batch.ndim != 4:
  122. raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
  123. if target.ndim != 1:
  124. raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
  125. if not batch.is_floating_point():
  126. raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
  127. if target.dtype != torch.int64:
  128. raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
  129. if not self.inplace:
  130. batch = batch.clone()
  131. target = target.clone()
  132. if target.ndim == 1:
  133. target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
  134. if torch.rand(1).item() >= self.p:
  135. return batch, target
  136. # It's faster to roll the batch by one instead of shuffling it to create image pairs
  137. batch_rolled = batch.roll(1, 0)
  138. target_rolled = target.roll(1, 0)
  139. # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
  140. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
  141. _, H, W = F.get_dimensions(batch)
  142. r_x = torch.randint(W, (1,))
  143. r_y = torch.randint(H, (1,))
  144. r = 0.5 * math.sqrt(1.0 - lambda_param)
  145. r_w_half = int(r * W)
  146. r_h_half = int(r * H)
  147. x1 = int(torch.clamp(r_x - r_w_half, min=0))
  148. y1 = int(torch.clamp(r_y - r_h_half, min=0))
  149. x2 = int(torch.clamp(r_x + r_w_half, max=W))
  150. y2 = int(torch.clamp(r_y + r_h_half, max=H))
  151. batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
  152. lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
  153. target_rolled.mul_(1.0 - lambda_param)
  154. target.mul_(lambda_param).add_(target_rolled)
  155. return batch, target
  156. def __repr__(self) -> str:
  157. s = (
  158. f"{self.__class__.__name__}("
  159. f"num_classes={self.num_classes}"
  160. f", p={self.p}"
  161. f", alpha={self.alpha}"
  162. f", inplace={self.inplace}"
  163. f")"
  164. )
  165. return s