from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence

from lightning import Callback, LightningModule, Trainer

from ..base import LightningIRModule


class LambdaWarmupScheduler(ABC):
    """Base class for schedulers that apply an optional delay followed by a warmup phase."""

    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Return the multiplicative scheduling factor for ``current_step``."""
        ...

    def check_delay(self, current_step: int) -> bool:
        """Whether ``current_step`` still falls within the delay phase."""
        return current_step < self.num_delay_steps

    def check_warmup(self, current_step: int) -> bool:
        """Whether ``current_step`` still falls within the warmup phase."""
        return current_step < self.num_warmup_steps + self.num_delay_steps


class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    """Ramps the factor linearly to 1.0 during warmup, then decays it linearly to ``final_value``."""

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        if self.check_delay(current_step):
            return 0.0
        if self.check_warmup(current_step):
            # Ramp linearly from 0.0 to 1.0 over the warmup steps.
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        # After warmup, decay linearly from 1.0 towards final_value.
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)


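# A doctest-style sketch (hypothetical numbers) of the resulting schedule: with
# a 2-step delay, a 2-step warmup, and 10 training steps, the factor stays at
# 0.0 during the delay, ramps linearly to 1.0, then decays towards final_value:
#
#   >>> scheduler = LinearSchedulerWithLinearWarmup(num_warmup_steps=2, num_training_steps=10, num_delay_steps=2)
#   >>> [round(scheduler.value_lambda(step), 2) for step in range(10)]
#   [0.0, 0.0, 0.0, 0.5, 1.0, 0.83, 0.67, 0.5, 0.33, 0.17]

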
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    """Ramps the factor linearly to 1.0 during warmup, then keeps it constant."""

    def value_lambda(self, current_step: int) -> float:
        if self.check_delay(current_step):
            return 0.0
        if self.check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    """Ramps the factor quadratically to 1.0 during warmup, then keeps it constant."""

    def value_lambda(self, current_step: int) -> float:
        if self.check_delay(current_step):
            return 0.0
        if self.check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0


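# A doctest-style sketch (hypothetical numbers) contrasting the two warmup
# shapes: over a 4-step warmup the quadratic variant rises more slowly at
# first, and both hold at 1.0 once warmup is over:
#
#   >>> linear = ConstantSchedulerWithLinearWarmup(num_warmup_steps=4)
#   >>> quadratic = ConstantSchedulerWithQuadraticWarmup(num_warmup_steps=4)
#   >>> [linear.value_lambda(step) for step in range(6)]
#   [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
#   >>> [quadratic.value_lambda(step) for step in range(6)]
#   [0.0, 0.0625, 0.25, 0.5625, 1.0, 1.0]

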
class GenericScheduler(Callback, ABC):
    """Callback that schedules arbitrary float attributes of a module, addressed by dot-separated keys."""

    def __init__(self, keys: Sequence[str], *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.keys = keys
        # Original attribute values, captured when training starts.
        self.values: Dict[str, float] = {}

    @abstractmethod
    def step(self, key: str, current_step: int) -> float:
        """Return the scheduled value for ``key`` at ``current_step``."""
        ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        # Walk a dot-separated path: integer segments index into the object,
        # all other segments are attribute lookups.
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Record the original value behind every key so it can be restored later.
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # Re-compute every scheduled attribute before each training batch.
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the original values once training has finished.
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)


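# Key resolution sketch (the attribute path below is hypothetical): dot-separated
# keys are resolved one segment at a time, with integer segments acting as indices:
#
#   >>> scheduler.get_value("loss_functions.0.weight".split("."), pl_module)
#   # equivalent to: pl_module.loss_functions[0].weight

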
class GenericLinearSchedulerWithLinearWarmup(GenericScheduler, LinearSchedulerWithLinearWarmup):
    """Scales the tracked attributes with linear warmup followed by linear decay."""

    def step(self, key: str, current_step: int) -> float:
        value = self.values[key]
        return value * self.value_lambda(current_step)


class GenericConstantSchedulerWithLinearWarmup(GenericScheduler, ConstantSchedulerWithLinearWarmup):
    """Scales the tracked attributes with linear warmup, then holds them at their original values."""

    def step(self, key: str, current_step: int) -> float:
        value = self.values[key]
        return value * self.value_lambda(current_step)


class GenericConstantSchedulerWithQuadraticWarmup(GenericScheduler, ConstantSchedulerWithQuadraticWarmup):
    """Scales the tracked attributes with quadratic warmup, then holds them at their original values."""

    def step(self, key: str, current_step: int) -> float:
        value = self.values[key]
        return value * self.value_lambda(current_step)
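

# Usage sketch (the key and warmup length are hypothetical): the generic
# schedulers are plain Lightning callbacks, so they are passed to the Trainer,
# which then rescales pl_module.loss_functions[0].weight before every batch:
#
#   >>> from lightning import Trainer
#   >>> scheduler = GenericConstantSchedulerWithLinearWarmup(
#   ...     keys=["loss_functions.0.weight"], num_warmup_steps=1_000
#   ... )
#   >>> trainer = Trainer(callbacks=[scheduler])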