import warnings
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer


class FaissIndexer(Indexer):
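    """Base class for faiss-backed indexers.

    The concrete index is built from the subclass's ``INDEX_FACTORY`` string (formatted with the values of the
    index config) and searched with the inner product; for cosine similarity an L2-normalization step is prepended.
    Document embeddings produced by a bi-encoder are collected batch by batch via :meth:`add`.
    """
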
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        import faiss

        similarity_function = bi_encoder_config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.bi_encoder_config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)


class FaissFlatIndexer(FaissIndexer):
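    """Indexer for a flat (exhaustive, uncompressed) faiss index."""
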
83 INDEX_FACTORY = "Flat"
84
    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass


class _FaissTrainIndexer(FaissIndexer):
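    """Mixin for faiss indexes that require training.

    Buffers the first ``num_train_embeddings`` embeddings, trains the index on them once the buffer is full (or
    when :meth:`save` is called), and only afterwards adds embeddings to the index.
    """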

    INDEX_FACTORY = ""  # class only acts as a mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        # buffer for training embeddings, filled with NaN until populated
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.bi_encoder_config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # buffer training embeddings until num_train_embeddings is reached;
            # embeddings beyond that are returned and added to the index directly
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()


class FaissIVFIndexer(_FaissTrainIndexer):
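    """Indexer for an IVF (inverted file) faiss index.

    If ``num_train_embeddings`` is not set, it defaults to ``num_centroids * 256`` (the faiss default of 256
    training points per centroid).
    """
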
162 INDEX_FACTORY = "IVF{num_centroids},Flat"
163
    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.bi_encoder_config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

            # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
            # move index to cpu but leave quantizer on gpu
            index_ivf = faiss.extract_index_ivf(self.index)
            quantizer = index_ivf.quantizer
            gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
            index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)


class FaissPQIndexer(_FaissTrainIndexer):
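    """Indexer for a product-quantization (PQ) faiss index with an OPQ pre-transform."""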

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass


class FaissIVFPQIndexer(FaissIVFIndexer):
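    """Indexer for an IVF-PQ faiss index with an HNSW coarse quantizer and an OPQ pre-transform."""
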
244 INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"
245
    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)


class FaissIndexConfig(IndexConfig):
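    """Base configuration for faiss indexers."""
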
    indexer_class = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
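    """Configuration for a flat faiss index."""
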
    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):
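    """Base configuration for faiss indexers that require index training on a sample of embeddings."""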

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings


class FaissIVFIndexConfig(_FaissTrainIndexConfig):
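    """Configuration for an IVF faiss index.

    ``num_centroids`` sets the number of IVF clusters; ``ef_construction`` is only applied when the coarse
    quantizer is HNSW-based (as in the IVF-PQ variant).
    """
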
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissPQIndexConfig(_FaissTrainIndexConfig):
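    """Configuration for a PQ faiss index.

    ``num_subquantizers`` sets the number of PQ subquantizers and ``n_bits`` the number of bits per subquantizer.
    """
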
    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
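    """Configuration for an IVF-PQ faiss index, combining the IVF and PQ settings."""
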
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits