lr_get.py

import math

import torch


def adam(regularization, r_value, param, lr, betas):
    # Return an Adam optimizer; apply weight decay only when L2 regularization is selected.
    if regularization == 'L2':
        optimizer = torch.optim.Adam(param, lr=lr, betas=betas, weight_decay=r_value)
    else:
        optimizer = torch.optim.Adam(param, lr=lr, betas=betas)
    return optimizer


class lr_adjust:
    # Cosine learning-rate schedule with linear warmup, stepped once per batch.
    def __init__(self, args, step_epoch, epoch_finished):
        print(f"Initializing with lr_end_epoch: {args.lr_end_epoch}, step_epoch: {step_epoch}")
        self.lr_start = args.lr_start  # initial learning rate
        self.lr_end = args.lr_end_ratio * args.lr_start  # final learning rate
        self.lr_end_epoch = args.lr_end_epoch  # epoch at which the final learning rate is reached
        self.step_all = self.lr_end_epoch * step_epoch  # total number of adjustment steps
        if self.step_all == 0:
            raise ValueError("Error computing total adjustment steps: step_all must not be 0; "
                             "check the values of lr_end_epoch and step_epoch.")
        self.step_finished = epoch_finished * step_epoch  # steps already completed
        self.warmup_step = max(5, int(args.warmup_ratio * self.step_all))  # number of warmup steps
        print(f"Total adjustment steps step_all: {self.step_all}")  # shows the value of step_all
    def __call__(self, optimizer):
        self.step_finished += 1
        step_now = self.step_finished
        if self.step_all == 0:
            raise ValueError("Error when called: step_all must not be 0.")
        decay = step_now / self.step_all
        # Cosine decay from lr_start towards lr_end over step_all steps.
        lr = self.lr_end + (self.lr_start - self.lr_end) * math.cos(math.pi / 2 * decay)
        if step_now <= self.warmup_step:
            # Linear warmup: scale the learning rate from 10% up to 100%.
            lr = lr * (0.1 + 0.9 * step_now / self.warmup_step)
        lr = max(lr, 0.000001)  # floor to avoid a vanishing learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer
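

# A minimal sketch of how lr_adjust is meant to be used in a full training loop:
# the scheduler is called once per batch so the warmup/cosine schedule advances
# by one step every iteration. `model`, `loader`, and `loss_fn` are hypothetical
# placeholders and are not defined in this file.
def train_with_schedule(model, loader, loss_fn, args, epoch_finished=0):
    step_epoch = len(loader)  # steps per epoch
    lr_adjuster = lr_adjust(args, step_epoch, epoch_finished)
    optimizer = adam('L2', 0.01, model.parameters(), args.lr_start, (0.9, 0.999))
    for epoch in range(epoch_finished, args.lr_end_epoch):
        for inputs, targets in loader:
            optimizer = lr_adjuster(optimizer)  # update the learning rate before the step
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
    return model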


# Example arguments with some assumed values
class Args:
    def __init__(self):
        self.lr_start = 0.001
        self.lr_end_ratio = 0.1
        self.lr_end_epoch = 10
        self.warmup_ratio = 0.1


if __name__ == "__main__":
    # Fake input arguments and initial state
    args = Args()
    step_epoch = 100  # assume 100 steps per epoch
    epoch_finished = 0  # assume training starts from epoch 0

    # Initialize the scheduler
    lr_adjuster = lr_adjust(args, step_epoch, epoch_finished)

    # Create a dummy optimizer
    params = [torch.randn(10, 10, requires_grad=True)]
    optimizer = adam('L2', 0.01, params, args.lr_start, (0.9, 0.999))

    # Call lr_adjuster once to adjust the learning rate
    optimizer = lr_adjuster(optimizer)
    print(f"Adjusted learning rate: {optimizer.param_groups[0]['lr']}")