Source code for lightning_ir.retrieve.base.searcher

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, List, Sequence, Tuple, Type

import torch

from ...bi_encoder.model import BiEncoderEmbedding

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule, BiEncoderOutput


class Searcher(ABC):
    """Base class for index searchers. Loads document ids and lengths from an
    index directory and filters the document scores produced by subclasses
    down to the top-k per query."""

    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt")

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)

        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        self.doc_lengths = self.doc_lengths.to(self.device)

    @property
    @abstractmethod
    def num_embeddings(self) -> int: ...

    @abstractmethod
    def _search(
        self, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[torch.Tensor, torch.Tensor | None, List[int] | None]: ...

    def _filter_and_sort(
        self,
        doc_scores: torch.Tensor,
        doc_idcs: torch.Tensor | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        if (doc_idcs is None) != (num_docs is None):
            raise ValueError("doc_idcs and num_docs must both be None or both be not None")
        if doc_idcs is None and num_docs is None:
            # assume the whole index was searched, i.e., doc_scores contains a
            # score for every document for every query
            k = min(self.search_config.k, doc_scores.shape[0])
            values, idcs = torch.topk(doc_scores.view(-1, self.num_docs), k)
            num_queries = values.shape[0]
            values = values.view(-1)
            idcs = idcs.view(-1)
            doc_ids = [self.doc_ids[doc_idx] for doc_idx in idcs.cpu()]
            return values, doc_ids, [k] * num_queries

        assert doc_idcs is not None and num_docs is not None
        # otherwise doc_scores covers a per-query candidate subset: split the
        # scores and indices per query and take the top-k of each subset
        per_query_doc_scores = torch.split(doc_scores, num_docs)
        per_query_doc_idcs = torch.split(doc_idcs, num_docs)
        new_num_docs = []
        _doc_scores = []
        doc_ids = []
        for query_idx, scores in enumerate(per_query_doc_scores):
            k = min(self.search_config.k, scores.shape[0])
            values, idcs = torch.topk(scores, k)
            _doc_scores.append(values)
            doc_ids.extend([self.doc_ids[doc_idx] for doc_idx in per_query_doc_idcs[query_idx][idcs].cpu()])
            new_num_docs.append(k)
        doc_scores = torch.cat(_doc_scores)
        return doc_scores, doc_ids, new_num_docs

    def search(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        doc_scores, doc_idcs, num_docs = self._search(query_embeddings)
        doc_scores, doc_ids, num_docs = self._filter_and_sort(doc_scores, doc_idcs, num_docs)
        return doc_scores, doc_ids, num_docs

class SearchConfig:
    """Configuration for a Searcher; k is the number of documents to retrieve
    per query."""

    search_class: Type[Searcher] = Searcher

    def __init__(self, k: int = 10) -> None:
        self.k = k
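
A minimal usage sketch follows, assuming a hypothetical concrete subclass. None of the specifics below are part of this module: the FlatSearcher name, the embeddings.pt file, single-vector (one embedding per document) storage, and the .embeddings attribute access on the query embeddings are illustrative assumptions only.

# Hypothetical sketch, not part of lightning_ir: a brute-force searcher over
# a flat embedding matrix. The file name and attribute access are assumptions.
class FlatSearcher(Searcher):
    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        # Load the (assumed) flat embedding matrix before calling
        # super().__init__, which reads the num_embeddings property.
        self.index = torch.load(Path(index_dir) / "embeddings.pt")
        super().__init__(index_dir, search_config, module, use_gpu)

    @property
    def num_embeddings(self) -> int:
        # Must equal doc_lengths.sum(), or the base __init__ raises ValueError.
        return self.index.shape[0]

    def _search(
        self, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[torch.Tensor, torch.Tensor | None, List[int] | None]:
        # Score every document for every query; returning (scores, None, None)
        # routes _filter_and_sort into its whole-index branch.
        scores = query_embeddings.embeddings @ self.index.T  # .embeddings is an assumption
        return scores.view(-1), None, None

Returning a per-query candidate subset instead would mean returning flat scores together with doc_idcs and num_docs, which routes _filter_and_sort into its per-query top-k branch.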