model_ema.py

import math
import copy

import torch


class model_ema:
    """Exponential moving average (EMA) of a model's weights,
    updated as ema = d * ema + (1 - d) * model."""

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # FP32 EMA copy of the unwrapped model, frozen and kept in eval mode
        self.ema = copy.deepcopy(self._get_model(model)).eval()
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramps from 0 towards `decay` over roughly `tau` updates,
        # so early (noisy) weights contribute less to the average
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)
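
    # Rough values of the ramped decay for the defaults (decay=0.9999, tau=2000):
    #   d(1)     ~ 0.0005  (EMA tracks the live model almost exactly at first)
    #   d(2000)  ~ 0.632   (1 - 1/e of the way to the cap)
    #   d(10000) ~ 0.993   (approaching the asymptotic 0.9999)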
    def update(self, model):
        # Blend the current model weights into the EMA copy in place
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            state_dict = self._get_model(model).state_dict()
            for k, v in self.ema.state_dict().items():
                # Only floating-point tensors (parameters and buffers) are averaged
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * state_dict[k].detach()
    def _get_model(self, model):
        # Unwrap (Distributed)DataParallel containers to reach the raw module
        if type(model) in (torch.nn.parallel.DataParallel,
                           torch.nn.parallel.DistributedDataParallel):
            return model.module
        return model
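
# A minimal usage sketch (hypothetical names: `MyModel`, `loader`, `loss_fn`,
# and `optimizer` are placeholders, not part of this file):
#
#     model = MyModel()
#     ema = model_ema(model)
#     for x, y in loader:
#         loss = loss_fn(model(x), y)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         ema.update(model)  # refresh the EMA after each optimizer step
#
# Evaluate or checkpoint `ema.ema` instead of `model` for smoother weights.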