Source code for lightning_ir.retrieve.faiss.faiss_searcher

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Tuple

import torch

from ...bi_encoder.model import BiEncoderEmbedding
from ..base import SearchConfig, Searcher

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule


class FaissSearcher(Searcher):
    def __init__(
        self,
        index_dir: Path | str,
        search_config: FaissSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = False,
    ) -> None:
        import faiss

        self.search_config: FaissSearchConfig
        # load the prebuilt FAISS index from disk and optionally move it to all available GPUs
        self.index = faiss.read_index(str(Path(index_dir) / "index.faiss"))
        if use_gpu and hasattr(faiss, "index_cpu_to_all_gpus"):
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        # if the index is IVF-based, apply the configured number of probed cells;
        # if its coarse quantizer is HNSW-based, also apply efSearch
        ivf_index = None
        try:
            ivf_index = faiss.extract_index_ivf(self.index)
        except RuntimeError:
            pass
        if ivf_index is not None:
            ivf_index.nprobe = search_config.n_probe
            quantizer = getattr(ivf_index, "quantizer", None)
            if quantizer is not None:
                downcasted_quantizer = faiss.downcast_index(quantizer)
                hnsw = getattr(downcasted_quantizer, "hnsw", None)
                if hnsw is not None:
                    hnsw.efSearch = search_config.ef_search
        super().__init__(index_dir, search_config, module, use_gpu)
    @property
    def num_embeddings(self) -> int:
        return self.index.ntotal

    @property
    def doc_is_single_vector(self) -> bool:
        return self.num_docs == self.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        query_embeddings = query_embeddings.to(self.device)
        candidate_scores, candidate_doc_idcs = self.candidate_retrieval(query_embeddings)
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        if self.search_config.imputation_strategy == "gather":
            # reconstruct the candidate documents' embeddings and re-score them exactly
            doc_embeddings, doc_idcs, num_docs = self.gather_imputation(candidate_doc_idcs, query_lengths)
            doc_scores = self.module.model.score(query_embeddings, doc_embeddings, num_docs)
        else:
            # approximate scores for (query vector, document) pairs missing from the candidate lists
            doc_scores, doc_idcs, num_docs = self.intra_ranking_imputation(
                candidate_scores, candidate_doc_idcs, query_lengths
            )
        return doc_scores, doc_idcs, num_docs

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor]:
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        candidate_scores, candidate_idcs = self.index.search(embeddings.float().cpu(), self.search_config.candidate_k)
        candidate_scores = torch.from_numpy(candidate_scores)
        candidate_idcs = torch.from_numpy(candidate_idcs)
        if self.doc_is_single_vector:
            candidate_doc_idcs = candidate_idcs.to(self.cumulative_doc_lengths.device)
        else:
            # map flat vector indices to document indices for multi-vector documents
            candidate_doc_idcs = torch.searchsorted(
                self.cumulative_doc_lengths,
                candidate_idcs.to(self.cumulative_doc_lengths.device),
                side="right",
            )
        return candidate_scores, candidate_doc_idcs

    def gather_imputation(
        self, candidate_doc_idcs: torch.Tensor, query_lengths: torch.Tensor
    ) -> Tuple[BiEncoderEmbedding, torch.Tensor, List[int]]:
        # unique doc_idcs per query
        doc_idcs_per_query = [
            list(sorted(set(idcs.reshape(-1).tolist())))
            for idcs in torch.split(candidate_doc_idcs, query_lengths.tolist())
        ]
        num_docs = [len(idcs) for idcs in doc_idcs_per_query]
        doc_idcs = torch.tensor(sum(doc_idcs_per_query, [])).to(candidate_doc_idcs)
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = torch.cat(
            [
                torch.arange(start.item(), start.item() + length.item())
                for start, length in zip(start_doc_idcs.cpu(), doc_lengths.cpu())
            ]
        )
        all_doc_embeddings = torch.from_numpy(self.index.reconstruct_batch(all_doc_idcs))
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            [embeddings for embeddings in torch.split(all_doc_embeddings, doc_lengths.tolist())],
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask, encoding=None)
        return doc_embeddings, doc_idcs, num_docs

    def intra_ranking_imputation(
        self,
        candidate_scores: torch.Tensor,
        candidate_doc_idcs: torch.Tensor,
        query_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        max_query_length = int(query_lengths.max().item())
        query_is_single_vector = max_query_length == 1

        if self.doc_is_single_vector:
            scores = candidate_scores.view(-1)
            doc_idcs = candidate_doc_idcs.view(-1)
            num_docs = torch.full((candidate_scores.shape[0],), candidate_scores.shape[1])
        else:
            # grab unique (query, doc) pairs across the per-query-vector candidate lists
            query_idcs = torch.arange(query_lengths.shape[0], device=query_lengths.device).repeat_interleave(
                query_lengths
            )
            query_candidate_idcs = torch.cat(
                [torch.arange(length.item(), device=query_lengths.device) for length in query_lengths]
            )
            paired_idcs = torch.stack(
                [
                    query_idcs.repeat_interleave(candidate_scores.shape[1]),
                    query_candidate_idcs.repeat_interleave(candidate_scores.shape[1]),
                    candidate_doc_idcs.view(-1),
                ]
            ).T
            unique_paired_idcs, inverse_idcs = torch.unique(paired_idcs[:, [0, 2]], return_inverse=True, dim=0)
            doc_idcs = unique_paired_idcs[:, 1]
            num_docs = unique_paired_idcs[:, 0].bincount()

            # accumulate max score per doc; (doc, query vector) pairs that never occur
            # in the candidate lists keep the +inf sentinel and are imputed below
            ranking_doc_idcs = torch.arange(doc_idcs.shape[0], device=query_lengths.device)[inverse_idcs]
            idcs = ranking_doc_idcs * max_query_length + paired_idcs[:, 1]
            shape = torch.Size((doc_idcs.shape[0], max_query_length))
            scores = torch.scatter_reduce(
                torch.full((shape.numel(),), float("inf"), device=query_lengths.device),
                0,
                idcs,
                candidate_scores.view(-1).to(query_lengths.device),
                "max",
                include_self=False,
            ).view(shape)

        if query_is_single_vector:
            scores = scores.squeeze(-1)
        else:
            # impute missing values
            if self.search_config.imputation_strategy == "min":
                impute_values = (
                    scores.masked_fill(scores == torch.finfo(scores.dtype).min, float("inf"))
                    .min(0, keepdim=True)
                    .values.expand_as(scores)
                )
            elif self.search_config.imputation_strategy is None:
                impute_values = torch.zeros_like(scores)
            else:
                raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")
            is_inf = torch.isinf(scores)
            scores[is_inf] = impute_values[is_inf]

            # aggregate score per query vector
            mask = (
                torch.arange(max_query_length, device=query_lengths.device) < query_lengths[:, None]
            ).repeat_interleave(num_docs, dim=0)
            scores = self.module.scoring_function._aggregate(
                scores, mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1)
        return scores, doc_idcs, num_docs.tolist()


class FaissSearchConfig(SearchConfig):
    search_class = FaissSearcher

    def __init__(
        self,
        k: int = 10,
        candidate_k: int = 100,
        imputation_strategy: Literal["min", "gather"] | None = None,
        n_probe: int = 1,
        ef_search: int = 16,
    ) -> None:
        super().__init__(k)
        self.candidate_k = candidate_k
        self.imputation_strategy = imputation_strategy
        self.n_probe = n_probe
        self.ef_search = ef_search
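
What follows is a minimal usage sketch, not part of the module above. It assumes an index directory containing the "index.faiss" file read in FaissSearcher.__init__ (for example, one built with Lightning IR's FAISS indexer); the checkpoint identifier, the model_name_or_path keyword, and all paths are placeholders.

from lightning_ir import BiEncoderModule
from lightning_ir.retrieve.faiss.faiss_searcher import FaissSearchConfig, FaissSearcher

# placeholder checkpoint identifier; any bi-encoder checkpoint should work here
module = BiEncoderModule(model_name_or_path="path/or/hub-id/of/bi-encoder")
search_config = FaissSearchConfig(
    k=10,  # number of documents to return per query
    candidate_k=100,  # vectors retrieved from FAISS before per-document aggregation
    imputation_strategy="min",  # or "gather" / None, as dispatched in _search above
    n_probe=8,  # IVF cells to probe; ignored for non-IVF indices
)
searcher = FaissSearcher("path/to/index_dir", search_config, module, use_gpu=False)

Raising candidate_k trades latency for recall: intra_ranking_imputation can only aggregate scores for (query vector, document) pairs that actually appear in the FAISS candidate lists, so a larger candidate pool leaves fewer scores to impute.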