import math
import torch


def adam(regularization, r_value, param, lr, betas):
    # Use Adam's built-in weight_decay as the L2 penalty when requested.
    if regularization == 'L2':
        optimizer = torch.optim.Adam(param, lr=lr, betas=betas, weight_decay=r_value)
    else:
        optimizer = torch.optim.Adam(param, lr=lr, betas=betas)
    return optimizer


class lr_adjust:
    """Cosine learning-rate decay with a short linear warmup."""

    def __init__(self, args, step_epoch, epoch_finished):
        print(f"Initializing lr_end_epoch: {args.lr_end_epoch}, step_epoch: {step_epoch}")
        self.lr_start = args.lr_start  # initial learning rate
        self.lr_end = args.lr_end_ratio * args.lr_start  # final learning rate
        self.lr_end_epoch = args.lr_end_epoch  # epoch at which the final learning rate is reached
        self.step_all = self.lr_end_epoch * step_epoch  # total number of adjustment steps
        if self.step_all == 0:
            raise ValueError("Error computing total adjustment steps: step_all must not be 0; "
                             "check the values of lr_end_epoch and step_epoch.")
        self.step_finished = epoch_finished * step_epoch  # steps already completed
        self.warmup_step = max(5, int(args.warmup_ratio * self.step_all))  # number of warmup steps
        print(f"Total adjustment steps step_all: {self.step_all}")  # shows the value of step_all

    def __call__(self, optimizer):
        self.step_finished += 1
        step_now = self.step_finished
        if self.step_all == 0:
            raise ValueError("Error on call: step_all must not be 0.")
        decay = step_now / self.step_all
        lr = self.lr_end + (self.lr_start - self.lr_end) * math.cos(math.pi / 2 * decay)
        if step_now <= self.warmup_step:
            lr = lr * (0.1 + 0.9 * step_now / self.warmup_step)  # linear warmup from 10% of the scheduled lr
        lr = max(lr, 0.000001)  # lower bound on the learning rate
        for i in range(len(optimizer.param_groups)):
            optimizer.param_groups[i]['lr'] = lr
        return optimizer
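

# Illustrative behaviour of the schedule above (values computed by hand, assuming
# the example settings used in __main__ below: lr_start=0.001, lr_end=0.0001,
# step_all=1000, warmup_step=100):
#   step 1    -> warmup factor 0.109, lr ≈ 1.09e-4
#   step 100  -> warmup finished, lr ≈ 9.89e-4 (close to lr_start)
#   step 1000 -> decay = 1.0, cos(pi/2) = 0, lr = lr_end = 1e-4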


# Example arguments with some assumed values
class Args:
    def __init__(self):
        self.lr_start = 0.001
        self.lr_end_ratio = 0.1
        self.lr_end_epoch = 10
        self.warmup_ratio = 0.1


if __name__ == "__main__":
    # Mock input arguments and initial state
    args = Args()
    step_epoch = 100  # assume 100 steps per epoch
    epoch_finished = 0  # assume training starts from epoch 0

    # Initialize the scheduler
    lr_adjuster = lr_adjust(args, step_epoch, epoch_finished)

    # Create a dummy optimizer
    params = [torch.randn(10, 10, requires_grad=True)]
    optimizer = adam('L2', 0.01, params, args.lr_start, (0.9, 0.999))

    # Call lr_adjuster to adjust the learning rate
    optimizer = lr_adjuster(optimizer)
    print(f"Adjusted learning rate: {optimizer.param_groups[0]['lr']}")