Source code for lightning_ir.main

import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Set

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

import lightning_ir  # noqa: F401
from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler

if torch.cuda.is_available():
    torch.set_float32_matmul_precision("medium")

sys.path.append(str(Path.cwd()))

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class LightningIRSaveConfigCallback(SaveConfigCallback):
    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if stage != "fit" or trainer.logger is None:
            return
        return super().setup(trainer, pl_module, stage)

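# Descriptive note: the override above writes the config file only during `fit`
# and only when a logger is attached, so index/search/re_rank runs do not emit
# a pl_config.yaml.
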
class LightningIRWandbLogger(WandbLogger):
    @property
    def save_dir(self) -> str | None:
        """Gets the save directory.

        Returns:
            The path to the save directory.

        """
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir

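# Descriptive note: `save_dir` guards against the dummy experiment that the
# underlying WandbLogger hands out on non-rank-zero processes, returning None
# instead of a bogus directory. A minimal, hypothetical wiring (names assumed,
# not taken from this module):
#
#     logger = LightningIRWandbLogger(project="lightning-ir")
#     trainer = LightningIRTrainer(logger=logger)
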
class LightningIRTrainer(Trainer):
    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Index a collection of documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Search for relevant documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Re-rank a set of retrieved documents."""
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

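# Usage sketch (hypothetical objects, not part of the original module): the
# three methods above simply delegate to `Trainer.test`, so a configured module
# and datamodule can be run through them programmatically as well as via the CLI:
#
#     trainer = LightningIRTrainer(logger=False)
#     trainer.index(model, datamodule=datamodule)
#     trainer.search(model, datamodule=datamodule)
#     trainer.re_rank(model, datamodule=datamodule)
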
class LightningIRCLI(LightningCLI):
    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: WarmupLRScheduler | None = None,
    ) -> Any:
        if lr_scheduler is None:
            return optimizer

        return [optimizer], [{"scheduler": lr_scheduler, "interval": lr_scheduler.interval}]

    def add_arguments_to_parser(self, parser):
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)

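# Config sketch (hedged; the class paths shown are assumptions, not taken from
# this file): thanks to the `link_arguments` calls above, a YAML config only
# needs to set `model_name_or_path` and `config` on the model; the values are
# mirrored into the data module automatically. For example:
#
#     model:
#       class_path: lightning_ir.BiEncoderModule
#       init_args:
#         model_name_or_path: bert-base-uncased
#     data:
#       class_path: lightning_ir.LightningIRDataModule
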
def main():
    """
    generate config using `python main.py fit --print_config > config.yaml`
    additional callbacks at:
    https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks

    Example:
        To obtain a default config:

            python main.py fit \
                --trainer.callbacks=ModelCheckpoint \
                --optimizer AdamW \
                --trainer.logger LightningIRWandbLogger \
                --print_config > default.yaml

        To run with the default config:

            python main.py fit \
                --config default.yaml

    """
    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
    )

if __name__ == "__main__":
    main()
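# Invocation sketch (subcommand names come from LightningIRCLI.subcommands();
# the config file names below are placeholders):
#
#     python main.py index --config index.yaml
#     python main.py search --config search.yaml
#     python main.py re_rank --config re_rank.yaml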