Source code for lightning_ir.callbacks.callbacks

  1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
  2
  3from __future__ import annotations
  4
  5import itertools
  6from dataclasses import is_dataclass
  7from pathlib import Path
  8from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar
  9
 10import pandas as pd
 11import torch
 12from lightning import LightningModule, Trainer
 13from lightning.pytorch.callbacks import Callback, TQDMProgressBar
 14
 15from ..data import RankBatch, SearchBatch
 16from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
 17from ..data.ir_datasets_utils import _register_local_dataset
 18from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
 19
 20if TYPE_CHECKING:
 21    from ..base import LightningIRModule, LightningIROutput
 22    from ..bi_encoder import BiEncoderModule, BiEncoderOutput
 23
 24T = TypeVar("T")
 25
 26
 27def _format_large_number(number: float) -> str:
 28    suffixes = ["", "K", "M", "B", "T"]
 29    suffix_index = 0
 30
 31    while number >= 1000 and suffix_index < len(suffixes) - 1:
 32        number /= 1000.0
 33        suffix_index += 1
 34
 35    formatted_number = "{:.2f}".format(number)
 36
 37    suffix = suffixes[suffix_index]
 38    if suffix:
 39        formatted_number += f" {suffix}"
 40    return formatted_number
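# Example (illustrative, not part of the module): _format_large_number(1_234_567)
# returns "1.23 M" and _format_large_number(999) returns "999.00".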


class _GatherMixin:
    """Mixin to gather dataclasses across all processes"""

    def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
        if is_dataclass(dataclass):
            return dataclass.__class__(
                **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
            )
        return pl_module.all_gather(dataclass)


class _IndexDirMixin:
    """Mixin to get index_dir"""

    index_dir: Path | str | None

    def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
        index_dir = self.index_dir
        if index_dir is None:
            default_index_dir = Path(pl_module.config.name_or_path)
            if default_index_dir.exists():
                index_dir = default_index_dir / "indexes"
            else:
                raise ValueError("No index_dir provided and model_name_or_path is not a path")
        index_dir = Path(index_dir)
        index_dir = index_dir / dataset.docs_dataset_id
        return index_dir


class _OverwriteMixin:
    """Mixin to skip datasets (for indexing or searching) if they already exist"""

    _get_save_path: Callable[[Trainer, LightningModule, int], Path]

    def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        overwrite = getattr(self, "overwrite", False)
        if not overwrite:
            datasets = list(trainer.datamodule.inference_datasets)
            remove_datasets = []
            for dataset_idx in range(len(datasets)):
                save_path = self._get_save_path(trainer, pl_module, dataset_idx)
                if save_path.exists():
                    remove_datasets.append(dataset_idx)
                    trainer.print(
                        f"`{save_path}` already exists. Skipping this dataset. Set overwrite=True to overwrite"
                    )
            for dataset_idx in remove_datasets[::-1]:
                del trainer.datamodule.inference_datasets[dataset_idx]


class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.

        :param index_config: Configuration for the indexer
        :type index_config: IndexConfig
        :param index_dir: Directory to save index(es) to. If None, indexes will be stored in the model's directory,
            defaults to None
        :type index_dir: Path | str | None, optional
        :param overwrite: Whether to skip or overwrite already existing indexes, defaults to False
        :type overwrite: bool, optional
        :param verbose: Toggle verbose output, defaults to False
        :type verbose: bool, optional
        """
        super().__init__()
        self.index_config = index_config
        self.index_dir = index_dir
        self.overwrite = overwrite
        self.verbose = verbose
        self.indexer: Indexer

    def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
        """Hook to setup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module used for indexing
        :type pl_module: BiEncoderModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        self._remove_overwrite_datasets(trainer, pl_module, stage)

    def _get_save_path(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Path:
        return self._get_index_dir(pl_module, trainer.datamodule.inference_datasets[dataset_idx])

    def _get_indexer(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Indexer:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")

        index_dir = self._get_save_path(trainer, pl_module, dataset_idx)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module.config, self.verbose)
        return indexer

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to check that the test datasets are configured correctly.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, DocDataset) for dataset in datasets):
            raise ValueError("Expected DocDatasets for indexing")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to setup the indexer between datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.indexer = self._get_indexer(trainer, pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param outputs: Encoded documents
        :type outputs: BiEncoderOutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")
            self.indexer.save()

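# Illustrative usage sketch (not part of the module): wiring the IndexCallback into a
# Lightning Trainer run over a DocDataset. The checkpoint path, the datamodule class
# (LightningIRDataModule) and the concrete index config (FaissFlatIndexConfig) are
# assumptions about the wider library and may differ in your installation.
#
#     from lightning import Trainer
#
#     module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")      # hypothetical checkpoint
#     datamodule = LightningIRDataModule(                                    # assumed datamodule class
#         inference_datasets=[DocDataset("msmarco-passage")],
#         inference_batch_size=32,
#     )
#     callback = IndexCallback(index_config=FaissFlatIndexConfig(), index_dir="./indexes")
#     Trainer(callbacks=[callback]).test(module, datamodule=datamodule)      # stage must be "test"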
class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
        """Callback to write a run file of ranked documents to disk.

        :param save_dir: Directory to save run files to. If None, run files will be saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
        """
        super().__init__()
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.run_name = run_name
        self.overwrite = overwrite
        self.run_dfs: List[pd.DataFrame] = []

    def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
        """Hook to setup the callback.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR module
        :type pl_module: LightningIRModule
        :param stage: Stage of the trainer, must be "test"
        :type stage: str
        :raises ValueError: If the stage is not "test"
        :raises ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local)
        """
        if stage != "test":
            raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
        if self.save_dir is None:
            default_save_dir = Path(pl_module.config.name_or_path)
            if default_save_dir.exists():
                self.save_dir = default_save_dir / "runs"
                print(f"Using default save_dir `{self.save_dir}` to save runs")
            else:
                raise ValueError("No save_dir provided and model_name_or_path is not a path")
        self._remove_overwrite_datasets(trainer, pl_module, stage)

    def _get_save_path(self, trainer: Trainer, pl_module: LightningIRModule, dataset_idx: int) -> Path:
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is None:
            raise ValueError("No datamodule found")
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        dataset = datamodule.inference_datasets[dataset_idx]
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        run_file_path = self._get_save_path(trainer, pl_module, dataloader_idx)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t")

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR Module
        :type pl_module: LightningIRModule
        :param outputs: Scored query-document pairs
        :type outputs: LightningIROutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(
                itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs)
            )
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []

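# Illustrative note (not part of the module): RankCallback writes one tab-separated run
# file per inference dataset. The columns follow RUN_HEADER from lightning_ir.data.dataset;
# the standard TREC run order would be query_id, q0, doc_id, rank, score, system, e.g.
#
#     q1    0    d42    1    12.73    BiEncoderModel
#
# The file name defaults to the dataset id (or the RunDataset file name) unless run_name is set.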
class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
        """Callback that uses an index to retrieve documents efficiently.

        :param search_config: Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
        :type search_config: SearchConfig
        :param index_dir: Directory where indexes are stored, defaults to None
        :type index_dir: Path | str | None, optional
        :param save_dir: Directory to save run files to. If None, run files are saved in the model's directory,
            defaults to None
        :type save_dir: Path | str | None, optional
        :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
            defaults to None
        :type run_name: str | None, optional
        :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
        :type overwrite: bool, optional
        :param use_gpu: Toggle to use gpu for retrieval, defaults to True
        :type use_gpu: bool, optional
        """
        super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
        self.search_config = search_config
        self.index_dir = index_dir
        self.overwrite = overwrite
        self.use_gpu = use_gpu
        self.searcher: Searcher

    def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if getattr(self, "searcher", None) is not None and self.searcher.index_dir == index_dir:
            return self.searcher

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
        """Hook to validate datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :raises ValueError: If no test_dataloaders are found
        :raises ValueError: If not all datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`
        """
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        datasets = [dataloader.dataset for dataloader in dataloaders]
        if not all(isinstance(dataset, QueryDataset) for dataset in datasets):
            raise ValueError("Expected QueryDatasets for searching")

    def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize searcher for new datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

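# Illustrative usage sketch (not part of the module): retrieving over a previously built
# index. The search config instance, module, and datamodule wiring are placeholders;
# index_dir must point at the directory used by IndexCallback for the same model.
#
#     callback = SearchCallback(
#         search_config=...,          # a SearchConfig subclass matching the index type
#         index_dir="./indexes",
#         save_dir="./runs",
#         use_gpu=False,
#     )
#     datamodule = ...                # inference datasets must all be QueryDataset instances
#     Trainer(callbacks=[callback]).test(module, datamodule=datamodule)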
class ReRankCallback(RankCallback):
    pass

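# Illustrative note (not part of the module): ReRankCallback behaves exactly like
# RankCallback; it is typically paired with a RunDataset so that an existing run file
# is re-scored by the model and written back out as a new run file.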
class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
        ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:

        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
        - ``.tsv`` or ``.qrels`` for qrels
        - ``.tsv`` for training n-tuples
        - ``.tsv`` or ``.run`` for scored documents / run files

        :param dataset_id: Dataset id
        :type dataset_id: str
        :param docs: Path to documents file or valid ir_datasets id from which documents should be taken,
            defaults to None
        :type docs: str | None, optional
        :param queries: Path to queries file or valid ir_datasets id from which queries should be taken,
            defaults to None
        :type queries: str | None, optional
        :param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
        :type qrels: str | None, optional
        :param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be
            taken, defaults to None
        :type docpairs: str | None, optional
        :param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
            defaults to None
        :type scoreddocs: str | None, optional
        :param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
        :type qrels_defs: Dict[int, str] | None, optional
        """
        super().__init__()
        self.dataset_id = dataset_id
        self.docs = docs
        self.queries = queries
        self.qrels = qrels
        self.docpairs = docpairs
        self.scoreddocs = scoreddocs
        self.qrels_defs = qrels_defs

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Hook that registers the dataset.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: PyTorch Lightning LightningModule
        :type pl_module: LightningModule
        :param stage: Stage of the trainer
        :type stage: str
        """
        _register_local_dataset(
            self.dataset_id, self.docs, self.queries, self.qrels, self.docpairs, self.scoreddocs, self.qrels_defs
        )
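# Illustrative usage sketch (not part of the module): registering local TSV files under a
# custom id. The dataset id and file paths are placeholders.
#
#     callback = RegisterLocalDatasetCallback(
#         dataset_id="my-corpus",
#         docs="data/docs.tsv",
#         queries="data/queries.tsv",
#         qrels="data/qrels.tsv",
#     )
#     # Once setup() has run (e.g. as part of a Trainer that holds this callback), the
#     # dataset can be loaded with ir_datasets.load("my-corpus") or referenced by id in
#     # LightningIR datasets such as QueryDataset("my-corpus").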