import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Set

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

import lightning_ir  # noqa: F401
from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler

# Trade float32 matmul precision for speed on CUDA devices.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("medium")

# Make modules in the current working directory importable.
sys.path.append(str(Path.cwd()))

# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LightningIRSaveConfigCallback(SaveConfigCallback):
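    """Saves the configuration only during the ``fit`` stage and only when a logger is attached."""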
    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Only save the config when fitting and a logger is available to write it to.
        if stage != "fit" or trainer.logger is None:
            return
        return super().setup(trainer, pl_module, stage)


class LightningIRWandbLogger(WandbLogger):
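    """WandbLogger that reports no save directory while the wandb experiment has not been initialized."""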
    @property
    def save_dir(self) -> str | None:
        """Gets the save directory.

        Returns:
            The path to the save directory, or ``None`` if the experiment has not been initialized.
        """
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir


class LightningIRTrainer(Trainer):
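    """Trainer that adds ``index``, ``search``, and ``re_rank`` subcommands for information retrieval tasks."""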
    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Index a collection of documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Search for relevant documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
72 self,
73 model: LightningModule | None = None,
74 dataloaders: Any | LightningDataModule | None = None,
75 ckpt_path: str | Path | None = None,
76 verbose: bool = True,
77 datamodule: LightningDataModule | None = None,
78 ) -> List[Mapping[str, float]]:
79 """Re-rank a set of retrieved documents."""
80 return super().test(model, dataloaders, ckpt_path, verbose, datamodule)
81
82
[docs]
83class LightningIRCLI(LightningCLI):
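    """LightningCLI with learning rate scheduler support and retrieval-specific subcommands."""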
    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: WarmupLRScheduler | None = None,
    ) -> Any:
        """Combines the optimizer with an optional warmup learning rate scheduler."""
        if lr_scheduler is None:
            return optimizer

        # Lightning expects a list of optimizers and a list of scheduler configs; the interval controls when the scheduler steps.
        return [optimizer], [{"scheduler": lr_scheduler, "interval": lr_scheduler.interval}]

    def add_arguments_to_parser(self, parser):
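        """Registers learning rate scheduler arguments and links shared model, data, and scheduler arguments."""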
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
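        """Defines the CLI subcommands and the arguments they accept."""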
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
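        """Same as the parent implementation, but with warnings suppressed."""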
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)


def main():
    """Runs the Lightning IR command-line interface.

    Generate a config using ``python main.py fit --print_config > config.yaml``.
    Additional callbacks are listed at:
    https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks

    Example:
        To obtain a default config:

            python main.py fit \
                --trainer.callbacks=ModelCheckpoint \
                --optimizer AdamW \
                --trainer.logger LightningIRWandbLogger \
                --print_config > default.yaml

        To run with the default config:

            python main.py fit \
                --config default.yaml
    """
    # Instantiating the CLI parses the command-line arguments and runs the selected subcommand.
    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
    )


if __name__ == "__main__":
    main()