Source code for lightning_ir.schedulers.lr_schedulers
1import torch
2
3from .schedulers import ConstantSchedulerWithLinearWarmup, LambdaWarmupScheduler, LinearSchedulerWithLinearWarmup
4
5
[docs]
class WarmupLRScheduler(LambdaWarmupScheduler, torch.optim.lr_scheduler.LambdaLR):
    """Learning-rate scheduler that plugs a warmup lambda into ``LambdaLR``.

    ``torch.optim.lr_scheduler.LambdaLR`` multiplies each parameter group's base
    learning rate by ``lr_lambda(counter)``. Here ``lr_lambda`` is
    ``self.value_lambda``, which must be provided by another class in the MRO.
    Presumably that is ``LambdaWarmupScheduler`` or a concrete schedule mixed in
    by a subclass; TODO confirm against ``.schedulers``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        *args,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Initialise the scheduler.

        Args:
            optimizer: Optimizer whose parameter-group learning rates are scheduled.
            num_warmup_steps: Number of warmup steps, forwarded as a keyword
                argument up the MRO.
            *args: Extra positional arguments forwarded to ``super().__init__``.
            verbose: Legacy PyTorch ``verbose`` flag. It is forwarded only when
                truthy (see the comment below).
            **kwargs: Extra keyword arguments forwarded to ``super().__init__``.
        """
        # Always start from a fresh schedule; no resume-from-epoch support here.
        last_epoch = -1
        # Tag the scheduler as step-based rather than epoch-based.
        # NOTE(review): presumably read by the Lightning scheduler config; verify.
        self.interval = "step"
        if verbose:
            # PyTorch >= 2.2 deprecates ``verbose`` and warns whenever it is passed
            # explicitly, even as ``False``. Forward it only on explicit request so
            # the default path stays warning-free. ``verbose`` is keyword-only, so
            # ``kwargs`` can never already contain it.
            kwargs["verbose"] = verbose
        super().__init__(
            *args,
            optimizer=optimizer,
            lr_lambda=self.value_lambda,
            num_warmup_steps=num_warmup_steps,
            last_epoch=last_epoch,
            **kwargs,
        )
26
27
[docs]
class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler, LinearSchedulerWithLinearWarmup):
    """Learning-rate scheduler using the ``LinearSchedulerWithLinearWarmup`` schedule.

    Everything is inherited. ``WarmupLRScheduler`` wires the schedule's
    ``value_lambda`` into ``torch.optim.lr_scheduler.LambdaLR``, and
    ``LinearSchedulerWithLinearWarmup`` supplies the schedule itself. See that
    class for the exact curve and its constructor arguments.
    """
30
31
[docs]
class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler, ConstantSchedulerWithLinearWarmup):
    """Learning-rate scheduler using the ``ConstantSchedulerWithLinearWarmup`` schedule.

    Everything is inherited. ``WarmupLRScheduler`` wires the schedule's
    ``value_lambda`` into ``torch.optim.lr_scheduler.LambdaLR``, and
    ``ConstantSchedulerWithLinearWarmup`` supplies the schedule itself. See that
    class for the exact curve and its constructor arguments.
    """
34
35
# Learning-rate scheduler classes known to this module: the two concrete warmup
# schedulers, their shared base, and torch's generic ``LRScheduler`` base class.
# NOTE(review): presumably consumed by callers for type checks / CLI type
# resolution; verify usage before reordering or pruning entries.
LR_SCHEDULERS = (
    LinearLRSchedulerWithLinearWarmup,
    ConstantLRSchedulerWithLinearWarmup,
    WarmupLRScheduler,
    torch.optim.lr_scheduler.LRScheduler,
)