1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from pathlib import Path
5from typing import TYPE_CHECKING, List, Sequence, Tuple, Type
6
7import torch
8
9from ...bi_encoder.model import BiEncoderEmbedding
10
11if TYPE_CHECKING:
12 from ...bi_encoder import BiEncoderModule, BiEncoderOutput
13
14
class Searcher(ABC):
    """Abstract base class for searching a pre-built index with bi-encoder query embeddings.

    Subclasses provide :attr:`num_embeddings` and :meth:`_search`. This base class loads the document ids and
    per-document embedding counts from ``index_dir``, validates them, and reduces raw document scores to the
    top-``k`` documents of every query.
    """

    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        """Load the index metadata and check that it agrees with the index.

        :param index_dir: Directory containing ``doc_ids.txt`` and ``doc_lengths.pt``.
        :param search_config: Search configuration; its ``k`` bounds the number of documents returned per query.
        :param module: Bi-encoder module associated with this searcher (only stored here).
        :param use_gpu: Use CUDA when it is available, defaults to True.
        :raises ValueError: If ``doc_lengths`` does not have one entry per document id or its sum differs from
            :attr:`num_embeddings`.
        """
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        # ids are whitespace-separated tokens, so a single id must not contain whitespace
        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        # one entry per document (validated below); the entries must sum to num_embeddings
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt")

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)

        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        """Move ``doc_lengths`` to :attr:`device` (the CPU when CUDA is disabled or unavailable)."""
        self.doc_lengths = self.doc_lengths.to(self.device)

    @property
    @abstractmethod
    def num_embeddings(self) -> int:
        """Number of embeddings stored in the index; must equal ``doc_lengths.sum()``."""
        ...

    @abstractmethod
    def _search(
        self, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[torch.Tensor, torch.Tensor | None, List[int] | None]:
        """Score documents for the given query embeddings.

        :return: ``(doc_scores, doc_idcs, num_docs)`` in one of the two layouts accepted by
            :meth:`_filter_and_sort`.
        """
        ...

    def _filter_and_sort(
        self,
        doc_scores: torch.Tensor,
        doc_idcs: torch.Tensor | None,
        num_docs: Sequence[int] | None,
    ) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Keep the top-``k`` documents of every query and map index positions to document ids.

        Two input layouts are accepted:

        * ``doc_idcs`` and ``num_docs`` both ``None``: every document of the index was scored for every query, and
          ``doc_scores`` is viewed as ``(num_queries, self.num_docs)``.
        * otherwise ``doc_scores`` and ``doc_idcs`` are flat and aligned; the first ``num_docs[0]`` entries belong
          to the first query, the next ``num_docs[1]`` to the second, and so on.

        :param doc_scores: Document scores.
        :param doc_idcs: Positions (into ``self.doc_ids``) of the scored documents, or ``None``.
        :param num_docs: Number of scored documents per query, or ``None``.
        :raises ValueError: If exactly one of ``doc_idcs`` and ``num_docs`` is ``None``.
        :return: Retained scores of all queries concatenated in query order (descending within each query), the
            matching document ids, and the number of documents retained per query.
        """
        if (doc_idcs is None) != (num_docs is None):
            raise ValueError("doc_idcs and num_docs must be both None or not None")
        if doc_idcs is None and num_docs is None:
            # assume we have searched the whole index: one row of num_docs scores per query
            per_query_scores = doc_scores.view(-1, self.num_docs)
            # clamp k by the dimension topk runs over (documents per query), not by the size of the score tensor
            k = min(self.search_config.k, per_query_scores.shape[1])
            values, idcs = torch.topk(per_query_scores, k)
            num_queries = values.shape[0]
            doc_ids = [self.doc_ids[doc_idx] for doc_idx in idcs.view(-1).tolist()]
            return values.view(-1), doc_ids, [k] * num_queries

        assert doc_idcs is not None and num_docs is not None
        split_sizes = list(num_docs)
        top_scores: List[torch.Tensor] = []
        doc_ids: List[str] = []
        new_num_docs: List[int] = []
        for scores, candidate_idcs in zip(torch.split(doc_scores, split_sizes), torch.split(doc_idcs, split_sizes)):
            k = min(self.search_config.k, scores.shape[0])
            values, top_positions = torch.topk(scores, k)
            top_scores.append(values)
            doc_ids.extend(self.doc_ids[doc_idx] for doc_idx in candidate_idcs[top_positions].tolist())
            new_num_docs.append(k)
        if not top_scores:
            # no queries: torch.cat would fail on an empty list
            return doc_scores[:0], doc_ids, new_num_docs
        return torch.cat(top_scores), doc_ids, new_num_docs

    def search(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, List[str], List[int]]:
        """Retrieve the top-``k`` documents for every query embedding in ``output``.

        :param output: Bi-encoder output whose ``query_embeddings`` are searched.
        :raises ValueError: If ``output.query_embeddings`` is ``None``.
        :return: Scores, document ids, and number of documents per query, as returned by
            :meth:`_filter_and_sort`.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        doc_scores, doc_idcs, num_docs = self._search(query_embeddings)
        return self._filter_and_sort(doc_scores, doc_idcs, num_docs)
89
90
class SearchConfig:
    """Configuration consumed by a :class:`Searcher`.

    :param k: Maximum number of documents kept per query, defaults to 10.
    """

    # Searcher implementation paired with this configuration.
    search_class: Type[Searcher] = Searcher

    def __init__(self, k: int = 10) -> None:
        """Record the retrieval depth ``k``."""
        self.k: int = k