1"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""
2
3from __future__ import annotations
4
5import itertools
6from dataclasses import is_dataclass
7from pathlib import Path
8from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar
9
10import pandas as pd
11import torch
12from lightning import LightningModule, Trainer
13from lightning.pytorch.callbacks import Callback, TQDMProgressBar
14
15from ..data import RankBatch, SearchBatch
16from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
17from ..data.ir_datasets_utils import _register_local_dataset
18from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
19
20if TYPE_CHECKING:
21 from ..base import LightningIRModule, LightningIROutput
22 from ..bi_encoder import BiEncoderModule, BiEncoderOutput
23
24T = TypeVar("T")
25
26
27def _format_large_number(number: float) -> str:
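    """Format a large number with a thousands suffix for display, e.g. ``12_345_678`` becomes ``"12.35 M"``."""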
    suffixes = ["", "K", "M", "B", "T"]
    suffix_index = 0

    while number >= 1000 and suffix_index < len(suffixes) - 1:
        number /= 1000.0
        suffix_index += 1

    formatted_number = "{:.2f}".format(number)

    suffix = suffixes[suffix_index]
    if suffix:
        formatted_number += f" {suffix}"
    return formatted_number


class _GatherMixin:
44 """Mixin to gather dataclasses across all processes"""
45
46 def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
47 if is_dataclass(dataclass):
48 return dataclass.__class__(
49 **{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
50 )
51 return pl_module.all_gather(dataclass)
52
53
54class _IndexDirMixin:
55 """Mixin to get index_dir"""
56
57 index_dir: Path | str | None
58
59 def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
60 index_dir = self.index_dir
61 if index_dir is None:
62 default_index_dir = Path(pl_module.config.name_or_path)
63 if default_index_dir.exists():
64 index_dir = default_index_dir / "indexes"
65 else:
66 raise ValueError("No index_dir provided and model_name_or_path is not a path")
67 index_dir = Path(index_dir)
68 index_dir = index_dir / dataset.docs_dataset_id
69 return index_dir
70
71
72class _OverwriteMixin:
73 """Mixin to skip datasets (for indexing or searching) if they already exist"""
74
75 _get_save_path: Callable[[Trainer, LightningModule, int], Path]
76
77 def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
78 overwrite = getattr(self, "overwrite", False)
79 if not overwrite:
80 datasets = list(trainer.datamodule.inference_datasets)
81 remove_datasets = []
82 for dataset_idx in range(len(datasets)):
83 save_path = self._get_save_path(trainer, pl_module, dataset_idx)
84 if save_path.exists():
85 remove_datasets.append(dataset_idx)
86 trainer.print(
87 f"`{save_path}` already exists. Skipping this dataset. Set overwrite=True to overwrite"
88 )
89 for dataset_idx in remove_datasets[::-1]:
90 del trainer.datamodule.inference_datasets[dataset_idx]
91
92
[docs]
93class IndexCallback(Callback, _GatherMixin, _IndexDirMixin, _OverwriteMixin):
    def __init__(
        self,
        index_config: IndexConfig,
        index_dir: Path | str | None = None,
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
101 """Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.
102
103 :param index_config: Configuration for the indexer
104 :type index_config: IndexConfig
105 :param index_dir: Directory to save index(es) to. If None, indexes will be stored in the model's directory,
106 defaults to None
107 :type index_dir: Path | str | None, optional
108 :param overwrite: Whether to skip or overwrite already existing indexes, defaults to False
109 :type overwrite: bool, optional
110 :param verbose: Toggle verbose output, defaults to False
111 :type verbose: bool, optional
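
        Example (a minimal sketch; ``index_config`` stands for a concrete ``IndexConfig`` subclass, and the
        module and datamodule are placeholders)::

            callback = IndexCallback(index_config=index_config, index_dir="./indexes")
            trainer = Trainer(callbacks=[callback])
            trainer.test(bi_encoder_module, datamodule=datamodule)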
112 """
113 super().__init__()
114 self.index_config = index_config
115 self.index_dir = index_dir
116 self.overwrite = overwrite
117 self.verbose = verbose
118 self.indexer: Indexer
119
[docs]
120 def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
121 """Hook to setup the callback.
122
123 :param trainer: PyTorch Lightning Trainer
124 :type trainer: Trainer
125 :param pl_module: LightningIR bi-encoder module used for indexing
126 :type pl_module: BiEncoderModule
127 :param stage: Stage of the trainer, must be "test"
128 :type stage: str
129 :raises ValueError: If the stage is not "test"
130 """
131 if stage != "test":
132 raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
133 self._remove_overwrite_datasets(trainer, pl_module, stage)
134
135 def _get_save_path(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Path:
        return self._get_index_dir(pl_module, trainer.datamodule.inference_datasets[dataset_idx])

    def _get_indexer(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Indexer:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")

        index_dir = self._get_save_path(trainer, pl_module, dataset_idx)

        indexer = self.index_config.indexer_class(index_dir, self.index_config, pl_module.config, self.verbose)
        return indexer

    def _log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
        pg_callback = trainer.progress_bar_callback
        if pg_callback is None or not isinstance(pg_callback, TQDMProgressBar):
            return
        pg = pg_callback.test_progress_bar
        info = {k: _format_large_number(v) for k, v in info.items()}
        if pg is not None:
            pg.set_postfix(info)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
158 """Hook to test datasets are configured correctly.
159
160 :param trainer: PyTorch Lightning Trainer
161 :type trainer: Trainer
162 :param pl_module: LightningIR BiEncoderModule
163 :type pl_module: BiEncoderModule
164 :raises ValueError: If no test_dataloaders are found
165 :raises ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`
166 """
167 dataloaders = trainer.test_dataloaders
168 if dataloaders is None:
169 raise ValueError("No test_dataloaders found")
170 datasets = [dataloader.dataset for dataloader in dataloaders]
171 if not all(isinstance(dataset, DocDataset) for dataset in datasets):
172 raise ValueError("Expected DocDatasets for indexing")
173
[docs]
174 def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to set up the indexer between datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.indexer = self._get_indexer(trainer, pl_module, dataloader_idx)
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: BiEncoderModule,
        outputs: BiEncoderOutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to pass encoded documents to the indexer.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR bi-encoder module
        :type pl_module: BiEncoderModule
        :param outputs: Encoded documents
        :type outputs: BiEncoderOutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)

        if not trainer.is_global_zero:
            return

        self.indexer.add(batch, outputs)
        self._log_to_pg(
            {
                "num_docs": self.indexer.num_docs,
                "num_embeddings": self.indexer.num_embeddings,
            },
            trainer,
        )
        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            assert hasattr(self, "indexer")
            self.indexer.save()


class RankCallback(Callback, _GatherMixin, _OverwriteMixin):
    def __init__(
        self, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False
    ) -> None:
241 """Callback to write run file of ranked documents to disk.
242
243 :param save_dir: Directory to save run files to. If None, run files will be saved in the models' directory,
244 defaults to None
245 :type save_dir: Path | str | None, optional
246 :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
247 defaults to None
248 :type run_name: str | None, optional
249 :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
250 :type overwrite: bool, optional
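
        Example (a minimal sketch; the module and datamodule are placeholders and must produce scored
        query-document pairs)::

            callback = RankCallback(save_dir="./runs")
            trainer = Trainer(callbacks=[callback])
            trainer.test(module, datamodule=datamodule)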
251 """
252 super().__init__()
253 self.save_dir = Path(save_dir) if save_dir is not None else None
254 self.run_name = run_name
255 self.overwrite = overwrite
256 self.run_dfs: List[pd.DataFrame] = []
257
[docs]
258 def setup(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
259 """Hook to setup the callback.
260
261 :param trainer: PyTorch Lightning Trainer
262 :type trainer: Trainer
263 :param pl_module: LightningIR module
264 :type pl_module: LightningIRModule
265 :param stage: Stage of the trainer, must be "test"
266 :type stage: str
267 :raises ValueError: If the stage is not "test"
268 :raises ValueError: If no save_dir is provided and model_name_or_path is not a path (the model is not local)
269 """
270 if stage != "test":
271 raise ValueError(f"{self.__class__.__name__} can only be used in test stage")
272 if self.save_dir is None:
273 default_save_dir = Path(pl_module.config.name_or_path)
274 if default_save_dir.exists():
275 self.save_dir = default_save_dir / "runs"
276 print(f"Using default save_dir `{self.save_dir}` to save runs")
277 else:
278 raise ValueError("No save_dir provided and model_name_or_path is not a path")
279 self._remove_overwrite_datasets(trainer, pl_module, stage)
280
281 def _get_save_path(self, trainer: Trainer, pl_module: LightningIRModule, dataset_idx: int) -> Path:
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is None:
            raise ValueError("No datamodule found")
        if self.save_dir is None:
            raise ValueError("No save_dir found; call setup before using this method")
        dataset = datamodule.inference_datasets[dataset_idx]
        if self.run_name is not None:
            run_file = self.run_name
        elif isinstance(dataset, QueryDataset):
            run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
        elif isinstance(dataset, RunDataset):
            if dataset.run_path is None:
                run_file = f"{dataset.dataset_id.replace('/', '-')}.run"
            else:
                run_file = f"{dataset.run_path.name.split('.')[0]}.run"
        run_file_path = self.save_dir / run_file
        return run_file_path

    def _rank(self, batch: RankBatch, output: LightningIROutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        scores = output.scores
        if scores is None:
            raise ValueError("Expected output to have scores")
        doc_ids = batch.doc_ids
        if doc_ids is None:
            raise ValueError("Expected batch to have doc_ids")
        scores = scores.view(-1)
        num_docs = [len(_doc_ids) for _doc_ids in doc_ids]
        doc_ids = list(itertools.chain.from_iterable(doc_ids))
        if scores.shape[0] != len(doc_ids):
            raise ValueError("scores and doc_ids must have the same length")
        return scores, doc_ids, num_docs

    def _write_run_dfs(self, trainer: Trainer, pl_module: LightningIRModule, dataloader_idx: int):
        if not trainer.is_global_zero or not self.run_dfs:
            return
        run_file_path = self._get_save_path(trainer, pl_module, dataloader_idx)
        run_file_path.parent.mkdir(parents=True, exist_ok=True)
        run_df = pd.concat(self.run_dfs, ignore_index=True)
        run_df.to_csv(run_file_path, header=False, index=False, sep="\t")

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningIRModule,
        outputs: LightningIROutput,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to aggregate and write ranking to file.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR Module
        :type pl_module: LightningIRModule
        :param outputs: Scored query-document pairs
        :type outputs: LightningIROutput
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in the current dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        batch = self._gather(pl_module, batch)
        outputs = self._gather(pl_module, outputs)
        if not trainer.is_global_zero:
            return

        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("Expected batch to have query_ids")
        scores, doc_ids, num_docs = self._rank(batch, outputs)
        scores = scores.float().cpu().numpy()

        query_ids = list(
            itertools.chain.from_iterable(itertools.repeat(query_id, num) for query_id, num in zip(query_ids, num_docs))
        )
        run_df = pd.DataFrame(zip(query_ids, doc_ids, scores), columns=["query_id", "doc_id", "score"])
        run_df = run_df.sort_values(["query_id", "score"], ascending=[True, False])
        run_df["rank"] = run_df.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
        run_df["q0"] = 0
        run_df["system"] = pl_module.model.__class__.__name__
        run_df = run_df[RUN_HEADER]

        self.run_dfs.append(run_df)

        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
            self._write_run_dfs(trainer, pl_module, dataloader_idx)
            self.run_dfs = []


class SearchCallback(RankCallback, _IndexDirMixin):
    def __init__(
        self,
        search_config: SearchConfig,
        index_dir: Path | str | None = None,
        save_dir: Path | str | None = None,
        run_name: str | None = None,
        overwrite: bool = False,
        use_gpu: bool = True,
    ) -> None:
385 """Callback to which uses index to retrieve documents efficiently.
386
387 :param search_config: Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
388 :type search_config: SearchConfig
389 :param index_dir: Directory where indexes are stored, defaults to None
390 :type index_dir: Path | str | None, optional
391 :param save_dir: Directory to save run files to. If None, run files are saved in the model's directory,
392 defaults to None
393 :type save_dir: Path | str | None, optional
394 :param run_name: Name of the run file. If None, the dataset's dataset_id or file name will be used,
395 defaults to None
396 :type run_name: str | None, optional
397 :param overwrite: Whether to skip or overwrite already existing run files, defaults to False
398 :type overwrite: bool, optional
399 :param use_gpu: Toggle to use gpu for retrieval, defaults to True
400 :type use_gpu: bool, optional
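
        Example (a minimal sketch; ``search_config`` stands for a concrete ``SearchConfig`` subclass and an
        index built by :py:class:`IndexCallback` is expected under ``index_dir``)::

            callback = SearchCallback(search_config=search_config, index_dir="./indexes", save_dir="./runs")
            trainer = Trainer(callbacks=[callback])
            trainer.test(bi_encoder_module, datamodule=datamodule)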
401 """
402 super().__init__(save_dir=save_dir, run_name=run_name, overwrite=overwrite)
403 self.search_config = search_config
404 self.index_dir = index_dir
405 self.overwrite = overwrite
406 self.use_gpu = use_gpu
407 self.searcher: Searcher
408
409 def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Searcher:
        dataloaders = trainer.test_dataloaders
        if dataloaders is None:
            raise ValueError("No test_dataloaders found")
        dataset = dataloaders[dataset_idx].dataset

        index_dir = self._get_index_dir(pl_module, dataset)
        if getattr(self, "searcher", None) is not None and self.searcher.index_dir == index_dir:
            return self.searcher

        searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
        return searcher

    def _rank(
        self, batch: SearchBatch | RankBatch, output: LightningIROutput
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if batch.doc_ids is None:
            raise ValueError("BiEncoderModule did not return doc_ids when searching")
        dummy_docs = [[""] * len(ids) for ids in batch.doc_ids]
        batch = RankBatch(batch.queries, dummy_docs, batch.query_ids, batch.doc_ids, batch.qrels)
        return super()._rank(batch, output)

    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
432 """Hook to validate datasets
433
434 :param trainer: PyTorch Lightning Trainer
435 :type trainer: Trainer
436 :param pl_module: LightningIR BiEncoderModule
437 :type pl_module: BiEncoderModule
438 :raises ValueError: If no test_dataloaders are found
439 :raises ValueError: If not all datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`
440 """
441 dataloaders = trainer.test_dataloaders
442 if dataloaders is None:
443 raise ValueError("No test_dataloaders found")
444 datasets = [dataloader.dataset for dataloader in dataloaders]
445 if not all(isinstance(dataset, QueryDataset) for dataset in datasets):
446 raise ValueError("Expected QueryDatasets for indexing")
447
[docs]
448 def on_test_batch_start(
        self, trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Hook to initialize the searcher for new datasets.

        :param trainer: PyTorch Lightning Trainer
        :type trainer: Trainer
        :param pl_module: LightningIR BiEncoderModule
        :type pl_module: BiEncoderModule
        :param batch: Batch of input data
        :type batch: Any
        :param batch_idx: Index of batch in dataset
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        """
        if batch_idx == 0:
            self.searcher = self._get_searcher(trainer, pl_module, dataloader_idx)
            pl_module.searcher = self.searcher
        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)


class ReRankCallback(RankCallback):
    pass


class RegisterLocalDatasetCallback(Callback):

    def __init__(
        self,
        dataset_id: str,
        docs: str | None = None,
        queries: str | None = None,
        qrels: str | None = None,
        docpairs: str | None = None,
        scoreddocs: str | None = None,
        qrels_defs: Dict[int, str] | None = None,
    ):
486 """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
487 ``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:
488
489 - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
490 - ``.tsv`` or ``.qrels`` for qrels
491 - ``.tsv`` for training n-tuples
492 - ``.tsv`` or ``.run`` for scored documents / run files
493
494 :param dataset_id: Dataset id
495 :type dataset_id: str
496 :param docs: Path to documents file or valid ir_datasets id from which documents should be taken,
497 defaults to None
498 :type docs: str | None, optional
499 :param queries: Path to queries file or valid ir_datastes id from which queries should be taken,
500 defaults to None
501 :type queries: str | None, optional
502 :param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
503 :type qrels: str | None, optional
504 :param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be taken,
505 defaults to None
506 :type docpairs: str | None, optional
507 :param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
508 defaults to None
509 :type scoreddocs: str | None, optional
510 :param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
511 :type qrels_defs: Dict[int, str] | None, optional
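
        Example (a minimal sketch; the dataset id and file paths are placeholders)::

            callback = RegisterLocalDatasetCallback(
                dataset_id="my-dataset", docs="docs.tsv", queries="queries.tsv", qrels="qrels.tsv"
            )
            trainer = Trainer(callbacks=[callback])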
512 """
513 super().__init__()
514 self.dataset_id = dataset_id
515 self.docs = docs
516 self.queries = queries
517 self.qrels = qrels
518 self.docpairs = docpairs
519 self.scoreddocs = scoreddocs
520 self.qrels_defs = qrels_defs
521
[docs]
522 def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
523 """Hook that registers dataset.
524
525 :param trainer: PyTorch Lightning Trainer
526 :type trainer: Trainer
527 :param pl_module: PyTorch Lightning LightningModule
528 :type pl_module: LightningModule
529 :param stage: Stage of the trainer
530 :type stage: str
531 """
532 _register_local_dataset(
533 self.dataset_id, self.docs, self.queries, self.qrels, self.docpairs, self.scoreddocs, self.qrels_defs
534 )