Source code for lightning_ir.retrieve.faiss.faiss_indexer

import warnings
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer


class FaissIndexer(Indexer):
    INDEX_FACTORY: str

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        import faiss

        similarity_function = bi_encoder_config.similarity_function
        if similarity_function in ("cosine", "dot"):
            self.metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise ValueError(f"similarity_function {similarity_function} unknown")

        index_factory = self.INDEX_FACTORY.format(**index_config.to_dict())
        if similarity_function == "cosine":
            # cosine similarity is computed as inner product over L2-normalized embeddings
            index_factory = "L2norm," + index_factory
        self.index = faiss.index_factory(self.bi_encoder_config.embedding_dim, index_factory, self.metric_type)

        self.set_verbosity()

        if torch.cuda.is_available():
            self.to_gpu()

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass

    def set_verbosity(self, verbose: bool | None = None) -> None:
        self.index.verbose = self.verbose if verbose is None else verbose

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        return embeddings

    def save(self) -> None:
        super().save()
        import faiss

        if self.num_embeddings != self.index.ntotal:
            raise ValueError("number of embeddings does not match index.ntotal")
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

        faiss.write_index(self.index, str(self.index_dir / "index.faiss"))

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            self.index.add(embeddings.float().cpu())

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)
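For illustration, the construction performed in FaissIndexer.__init__ above can be reproduced with faiss directly. The embedding dimension (128) and the "Flat" template are assumptions made only for this sketch; with a "cosine" similarity function the factory string gains the "L2norm," prefix, so inner-product search over the normalized vectors behaves like cosine similarity:

    import faiss
    import numpy as np

    # "cosine" similarity with INDEX_FACTORY = "Flat" yields the factory string "L2norm,Flat"
    index = faiss.index_factory(128, "L2norm,Flat", faiss.METRIC_INNER_PRODUCT)
    # analogous to FaissIndexer.add: embeddings are added as float32 arrays on the CPU
    index.add(np.random.rand(16, 128).astype("float32"))
    scores, ids = index.search(np.random.rand(1, 128).astype("float32"), 5)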
class FaissFlatIndexer(FaissIndexer):
    INDEX_FACTORY = "Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissFlatIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissFlatIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass


class _FaissTrainIndexer(FaissIndexer):

    INDEX_FACTORY = ""  # class only acts as mixin

    def __init__(
        self,
        index_dir: Path,
        index_config: "_FaissTrainIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        if index_config.num_train_embeddings is None:
            raise ValueError("num_train_embeddings must be set")
        self.num_train_embeddings = index_config.num_train_embeddings

        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.num_train_embeddings, self.bi_encoder_config.embedding_dim), torch.nan, dtype=torch.float32
        )

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # save training embeddings until num_train_embeddings is reached
            # if num_train_embeddings overflows, save the remaining embeddings
            start = self.num_embeddings
            end = min(self.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False):
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.num_train_embeddings:
            return
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]
        self.index.train(self._train_embeddings)
        if torch.cuda.is_available():
            self.to_cpu()
        self.index.add(self._train_embeddings)
        self._train_embeddings = None
        self.set_verbosity(False)

    def save(self) -> None:
        if not self.index.is_trained:
            self._train(force=True)
        return super().save()
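_FaissTrainIndexer buffers incoming embeddings until num_train_embeddings have been collected, trains the index once on that buffer, adds the buffered vectors, and afterwards passes embeddings straight through to FaissIndexer.add. A minimal sketch of the same train-then-add flow with plain faiss (the dimension, number of centroids, and corpus sizes below are arbitrary assumptions):

    import faiss
    import numpy as np

    index = faiss.index_factory(128, "IVF64,Flat", faiss.METRIC_INNER_PRODUCT)
    train_embeddings = np.random.rand(64 * 39, 128).astype("float32")
    assert not index.is_trained
    index.train(train_embeddings)  # corresponds to _train()
    index.add(train_embeddings)    # the training embeddings are also added to the index
    index.add(np.random.rand(1000, 128).astype("float32"))  # later batches are added directly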
class FaissIVFIndexer(_FaissTrainIndexer):
    INDEX_FACTORY = "IVF{num_centroids},Flat"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        # default faiss values
        # https://github.com/facebookresearch/faiss/blob/dafdff110489db7587b169a0afee8470f220d295/faiss/Clustering.h#L43
        max_points_per_centroid = 256
        index_config.num_train_embeddings = (
            index_config.num_train_embeddings or index_config.num_centroids * max_points_per_centroid
        )
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)

        import faiss

        ivf_index = faiss.extract_index_ivf(self.index)
        if hasattr(ivf_index, "quantizer"):
            quantizer = ivf_index.quantizer
            if hasattr(faiss.downcast_index(quantizer), "hnsw"):
                downcasted_quantizer = faiss.downcast_index(quantizer)
                downcasted_quantizer.hnsw.efConstruction = index_config.ef_construction

    def to_gpu(self) -> None:
        import faiss

        # clustering_index overrides the index used during clustering but leaves the quantizer on the gpu
        # https://faiss.ai/cpp_api/namespace/namespacefaiss_1_1gpu.html
        clustering_index = faiss.index_cpu_to_all_gpus(
            faiss.IndexFlat(self.bi_encoder_config.embedding_dim, self.metric_type)
        )
        clustering_index.verbose = self.verbose
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.clustering_index = clustering_index

    def to_cpu(self) -> None:
        import faiss

        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and hasattr(faiss, "index_cpu_to_gpu"):
            self.index = faiss.index_gpu_to_cpu(self.index)

            # https://gist.github.com/mdouze/334ad6a979ac3637f6d95e9091356d3e
            # move index to cpu but leave quantizer on gpu
            index_ivf = faiss.extract_index_ivf(self.index)
            quantizer = index_ivf.quantizer
            gpu_quantizer = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, quantizer)
            index_ivf.quantizer = gpu_quantizer

    def set_verbosity(self, verbose: bool | None = None) -> None:
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index = faiss.extract_index_ivf(self.index)
        for elem in (index, index.quantizer):
            setattr(elem, "verbose", verbose)


class FaissPQIndexer(_FaissTrainIndexer):

    INDEX_FACTORY = "OPQ{num_subquantizers},PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissPQIndexConfig

    def to_gpu(self) -> None:
        pass

    def to_cpu(self) -> None:
        pass


class FaissIVFPQIndexer(FaissIVFIndexer):
    INDEX_FACTORY = "OPQ{num_subquantizers},IVF{num_centroids}_HNSW32,PQ{num_subquantizers}x{n_bits}"

    def __init__(
        self,
        index_dir: Path,
        index_config: "FaissIVFPQIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        import faiss

        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        self.index_config: FaissIVFPQIndexConfig

        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.make_direct_map()

    def set_verbosity(self, verbose: bool | None = None) -> None:
        super().set_verbosity(verbose)
        import faiss

        verbose = verbose if verbose is not None else self.verbose
        index_ivf_pq = faiss.downcast_index(self.index.index)
        for elem in (
            index_ivf_pq.pq,
            index_ivf_pq.quantizer,
        ):
            setattr(elem, "verbose", verbose)


class FaissIndexConfig(IndexConfig):
    indexer_class = FaissIndexer


class FaissFlatIndexConfig(FaissIndexConfig):
    indexer_class = FaissFlatIndexer


class _FaissTrainIndexConfig(FaissIndexConfig):

    indexer_class = _FaissTrainIndexer

    def __init__(self, num_train_embeddings: int | None = None) -> None:
        super().__init__()
        self.num_train_embeddings = num_train_embeddings


class FaissIVFIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissIVFIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
    ) -> None:
        super().__init__(num_train_embeddings)
        self.num_centroids = num_centroids
        self.ef_construction = ef_construction


class FaissPQIndexConfig(_FaissTrainIndexConfig):
    indexer_class = FaissPQIndexer

    def __init__(self, num_train_embeddings: int | None = None, num_subquantizers: int = 16, n_bits: int = 8) -> None:
        super().__init__(num_train_embeddings)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits


class FaissIVFPQIndexConfig(FaissIVFIndexConfig):
    indexer_class = FaissIVFPQIndexer

    def __init__(
        self,
        num_train_embeddings: int | None = None,
        num_centroids: int = 262144,
        ef_construction: int = 40,
        num_subquantizers: int = 16,
        n_bits: int = 8,
    ) -> None:
        super().__init__(num_train_embeddings, num_centroids, ef_construction)
        self.num_subquantizers = num_subquantizers
        self.n_bits = n_bits
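The INDEX_FACTORY templates above are filled from the config's to_dict() output (see FaissIndexer.__init__). Assuming to_dict() exposes the constructor arguments shown here, a config with, for example, 1024 centroids expands as follows; "L2norm," is additionally prepended for models with cosine similarity:

    config = FaissIVFPQIndexConfig(num_centroids=1024, num_subquantizers=16, n_bits=8)
    factory_string = FaissIVFPQIndexer.INDEX_FACTORY.format(**config.to_dict())
    # factory_string == "OPQ16,IVF1024_HNSW32,PQ16x8"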