from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Tuple

import torch

from ...bi_encoder.model import BiEncoderEmbedding
from ..base import SearchConfig, Searcher

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule


class FaissSearcher(Searcher):
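    """Searcher for bi-encoder embeddings backed by a FAISS index.

    Retrieves the nearest neighbors of each query vector from a FAISS index stored as
    ``index.faiss`` in ``index_dir``, maps the returned vector indices to documents, and imputes
    scores for query-vector/document pairs that the candidate search did not return.

    Example (a minimal sketch; the checkpoint path and index directory are placeholders, and the
    index is assumed to have been built with the same bi-encoder)::

        module = BiEncoderModule(model_name_or_path="path/to/bi-encoder")
        config = FaissSearchConfig(k=10, candidate_k=100, imputation_strategy="gather")
        searcher = FaissSearcher("path/to/index-dir", config, module)
    """
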
    def __init__(
        self,
        index_dir: Path | str,
        search_config: FaissSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
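        """Initialize the FAISS searcher.

        Args:
            index_dir: Directory containing the ``index.faiss`` file.
            search_config: Configuration for candidate retrieval and score imputation.
            module: Bi-encoder module used to embed queries and score documents.
            use_gpu: Whether to move the index to all available GPUs. Defaults to False.
        """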
        import faiss

        # narrow the type of ``search_config`` for type checkers
        self.search_config: FaissSearchConfig
        self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
        if use_gpu and hasattr(faiss, "index_cpu_to_all_gpus"):
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        # set search-time hyperparameters on IVF indexes (and their HNSW quantizers, if present)
        ivf_index = None
        try:
            ivf_index = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            pass
        if ivf_index is not None:
            ivf_index.nprobe = search_config.n_probe
            quantizer = getattr(ivf_index, "quantizer", None)
            if quantizer is not None:
                downcasted_quantizer = faiss.downcast_index(quantizer)
                hnsw = getattr(downcasted_quantizer, "hnsw", None)
                if hnsw is not None:
                    hnsw.efSearch = search_config.ef_search
        super().__init__(index_dir, search_config, module, use_gpu)

    @property
    def num_embeddings(self) -> int:
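        """Total number of embeddings stored in the index."""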
        return self.index.ntotal

    @property
    def doc_is_single_vector(self) -> bool:
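        """Whether each document is represented by a single vector, i.e., the number of indexed
        embeddings equals the number of documents."""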
        return self.num_docs == self.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
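        """Retrieve candidates for the query embeddings and return document scores, flat document
        indices, and the number of documents per query."""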
        query_embeddings = query_embeddings.to(self.device)
        candidate_scores, candidate_doc_idcs = self.candidate_retrieval(query_embeddings)
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        if self.search_config.imputation_strategy == "gather":
            # reconstruct the candidates' full embeddings and rescore them exactly
            doc_embeddings, doc_idcs, num_docs = self.gather_imputation(candidate_doc_idcs, query_lengths)
            doc_scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)
        else:
            # approximate document scores from the candidate scores alone
            doc_scores, doc_idcs, num_docs = self.intra_ranking_imputation(
                candidate_scores, candidate_doc_idcs, query_lengths
            )
        return doc_scores, doc_idcs, num_docs

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor]:
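        """Search the index for the ``candidate_k`` nearest neighbors of each query vector and map
        the returned vector indices to document indices."""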
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        # FAISS operates on float32 CPU data
        candidate_scores, candidate_idcs = self.index.search(embeddings.float().cpu(), self.search_config.candidate_k)
        candidate_scores = torch.from_numpy(candidate_scores)
        candidate_idcs = torch.from_numpy(candidate_idcs)
        if self.doc_is_single_vector:
            candidate_doc_idcs = candidate_idcs.to(self.cumulative_doc_lengths.device)
        else:
            # map flat vector indices to document indices via the cumulative document lengths
            candidate_doc_idcs = torch.searchsorted(
                self.cumulative_doc_lengths,
                candidate_idcs.to(self.cumulative_doc_lengths.device),
                side="right",
            )
        return candidate_scores, candidate_doc_idcs

    def gather_imputation(
        self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
    ) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
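        """Reconstruct the full multi-vector embeddings of all candidate documents from the index so
        that every query-vector/document pair can be scored exactly."""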
        # unique doc_idcs per query
        doc_idcs_per_query = [
            sorted(set(idcs.reshape(-1).tolist()))
            for idcs in torch.split(candidate_doc_idcs, query_lengths.tolist())
        ]
        num_docs = [len(idcs) for idcs in doc_idcs_per_query]
        doc_idcs = torch.tensor(sum(doc_idcs_per_query, [])).to(candidate_doc_idcs)
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        # start offset of each document's vectors in the index; document 0 starts at offset 0
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = torch.cat(
            [
                torch.arange(start.item(), start.item() + length.item())
                for start, length in zip(start_doc_idcs.cpu(), doc_lengths.cpu())
            ]
        )
        all_doc_embeddings = torch.from_numpy(self.index.reconstruct_batch(all_doc_idcs))
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            list(torch.split(all_doc_embeddings, doc_lengths.tolist())),
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask, encoding=None)
        return doc_embeddings, doc_idcs, num_docs

    def intra_ranking_imputation(
        self,
        candidate_scores: torch.Tensor,
        candidate_doc_idcs: torch.Tensor,
        query_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
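        """Compute document scores from the candidate scores alone, imputing values for
        query-vector/document pairs that are missing from the candidate sets."""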
        max_query_length = int(query_lengths.max().item())
        query_is_single_vector = max_query_length == 1

        if self.doc_is_single_vector:
            scores = candidate_scores.view(-1)
            doc_idcs = candidate_doc_idcs.view(-1)
            num_docs = torch.full((candidate_scores.shape[0],), candidate_scores.shape[1])
        else:
            # grab unique doc ids per query candidate
            query_idcs = torch.arange(query_lengths.shape[0], device=query_lengths.device).repeat_interleave(
                query_lengths
            )
            query_candidate_idcs = torch.cat(
                [torch.arange(length.item(), device=query_lengths.device) for length in query_lengths]
            )
            paired_idcs = torch.stack(
                [
                    query_idcs.repeat_interleave(candidate_scores.shape[1]),
                    query_candidate_idcs.repeat_interleave(candidate_scores.shape[1]),
                    candidate_doc_idcs.view(-1),
                ]
            ).T
            unique_paired_idcs, inverse_idcs = torch.unique(paired_idcs[:, [0, 2]], return_inverse=True, dim=0)
            doc_idcs = unique_paired_idcs[:, 1]
            num_docs = unique_paired_idcs[:, 0].bincount()

            # accumulate max score per doc; slots no candidate hits keep the +inf fill
            # value and are imputed below
            ranking_doc_idcs = torch.arange(doc_idcs.shape[0], device=query_lengths.device)[inverse_idcs]
            idcs = ranking_doc_idcs * max_query_length + paired_idcs[:, 1]
            shape = torch.Size((doc_idcs.shape[0], max_query_length))
            scores = torch.scatter_reduce(
                torch.full((shape.numel(),), float("inf"), device=query_lengths.device),
                0,
                idcs,
                candidate_scores.view(-1).to(query_lengths.device),
                "max",
                include_self=False,
            ).view(shape)

        if query_is_single_vector:
            scores = scores.squeeze(-1)
        else:
            # impute missing values
            if self.search_config.imputation_strategy == "min":
                impute_values = (
                    scores.masked_fill(scores == torch.finfo(scores.dtype).min, float("inf"))
                    .min(0, keepdim=True)
                    .values.expand_as(scores)
                )
            elif self.search_config.imputation_strategy is None:
                impute_values = torch.zeros_like(scores)
            else:
                raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")
            is_inf = torch.isinf(scores)
            scores[is_inf] = impute_values[is_inf]

            # aggregate score per query vector
            mask = (
                torch.arange(max_query_length, device=query_lengths.device) < query_lengths[:, None]
            ).repeat_interleave(num_docs, dim=0)
            scores = self.module.scoring_function._aggregate(
                scores, mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1)
        return scores, doc_idcs, num_docs.tolist()


class FaissSearchConfig(SearchConfig):
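    """Configuration for :class:`FaissSearcher`."""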
    search_class = FaissSearcher

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather"] | None = None,
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
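        """Initialize the search configuration.

        Args:
            k: Number of documents to return per query. Defaults to 10.
            candidate_k: Number of nearest neighbors to retrieve per query vector. Defaults to 100.
            imputation_strategy: How to handle query-vector/document pairs missing from the
                candidate sets: "gather" reconstructs document embeddings and rescores them exactly,
                "min" imputes the minimum observed score, and None imputes zero. Defaults to None.
            n_probe: Number of inverted lists to probe in IVF indexes. Defaults to 1.
            ef_search: ``efSearch`` parameter for HNSW-based coarse quantizers. Defaults to 16.
        """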
        super().__init__(k)
        self.candidate_k = candidate_k
        self.imputation_strategy = imputation_strategy
        self.n_probe = n_probe
        self.ef_search = ef_search