Source code for lightning_ir.retrieve.plaid.plaid_searcher

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, List, Tuple, Type

import torch

from ...bi_encoder.model import BiEncoderEmbedding
from ..base.searcher import SearchConfig, Searcher
from .packed_tensor import PackedTensor
from .plaid_indexer import PlaidIndexConfig
from .residual_codec import ResidualCodec

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule


class PlaidSearcher(Searcher):
    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        super().__init__(index_dir, search_config, module, use_gpu)
        self.residual_codec = ResidualCodec.from_pretrained(
            PlaidIndexConfig.from_pretrained(self.index_dir), self.index_dir
        )

        self.codes = torch.load(self.index_dir / "codes.pt")
        self.residuals = torch.load(self.index_dir / "residuals.pt").view(self.codes.shape[0], -1)
        self.packed_codes = PackedTensor(self.codes, self.doc_lengths.tolist())
        self.packed_residuals = PackedTensor(self.residuals, self.doc_lengths.tolist())

        # code_idx to embedding_idcs mapping
        sorted_codes, embedding_idcs = self.codes.sort()
        num_embeddings_per_code = torch.bincount(sorted_codes, minlength=self.residual_codec.num_centroids).tolist()
        # self.code_to_embedding_ivf = PackedTensor(embedding_idcs, num_embeddings_per_code)

        # code_idx to doc_idcs mapping
        embedding_idx_to_doc_idx = torch.arange(self.num_docs).repeat_interleave(self.doc_lengths)
        full_doc_ivf = embedding_idx_to_doc_idx[embedding_idcs]
        doc_ivf_lengths = []
        unique_doc_idcs = []
        for doc_idcs in full_doc_ivf.split(num_embeddings_per_code):
            unique_doc_idcs.append(doc_idcs.unique())
            doc_ivf_lengths.append(unique_doc_idcs[-1].shape[0])
        self.code_to_doc_ivf = PackedTensor(torch.cat(unique_doc_idcs), doc_ivf_lengths)

        # doc_idx to code_idcs mapping
        sorted_doc_idcs, doc_idx_to_code_idx = self.code_to_doc_ivf.packed_tensor.sort()
        code_idcs = torch.arange(self.residual_codec.num_centroids).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths)
        )[doc_idx_to_code_idx]
        num_codes_per_doc = torch.bincount(sorted_doc_idcs, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(code_idcs, num_codes_per_doc.tolist())

        self.search_config: PlaidSearchConfig
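
        # Illustrative worked example (added for intuition, not part of the index):
        # with a single 3-token document (doc_lengths=[3]) whose embeddings are
        # assigned codes=[3, 1, 3], code_to_doc_ivf maps centroid 1 -> doc 0 and
        # centroid 3 -> doc 0 (every other centroid maps to no documents), while
        # doc_to_code_ivf maps doc 0 -> centroids [1, 3].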

    @property
    def num_embeddings(self) -> int:
        return int(self.cumulative_doc_lengths[-1].item())

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, PackedTensor]:
        # grab the top `n_cells` neighbor cells for all query embeddings
        # `num_queries x query_length x num_centroids`
        scores = (
            query_embeddings.embeddings.to(self.residual_codec.centroids)
            @ self.residual_codec.centroids.transpose(0, 1)[None]
        )
        scores = scores.masked_fill(~query_embeddings.scoring_mask[..., None], 0)
        _, codes = torch.topk(scores, self.search_config.n_cells, dim=-1, sorted=False)
        packed_codes = codes[query_embeddings.scoring_mask].view(-1)
        code_lengths = (query_embeddings.scoring_mask.sum(-1) * self.search_config.n_cells).tolist()

        # grab the document idcs for all cells
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)
        return scores, packed_doc_idcs

    def filter_candidates(
        self, centroid_scores: torch.Tensor, doc_idcs: PackedTensor, threshold: float | None, k: int
    ) -> PackedTensor:
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # repeat the query centroid scores for each document
        # `num_docs x num_query_vecs x (num_centroids + 1)`
        # NOTE the scores are padded with a trailing zero column so that codes padded
        # with the index `num_centroids` gather a value of 0
        expanded_centroid_scores = torch.nn.functional.pad(
            centroid_scores.repeat_interleave(torch.tensor(doc_idcs.lengths), dim=0), (0, 1)
        )

        # grab the codes for each document
        code_idcs = self.doc_to_code_ivf.lookup(doc_idcs.packed_tensor, 1)
        # `num_docs x max_num_codes_per_doc`
        padded_codes = code_idcs.to_padded_tensor(pad_value=num_centroids)
        mask = padded_codes != num_centroids
        # `num_docs x num_query_vecs x max_num_codes_per_doc`
        padded_codes = padded_codes[:, None].expand(-1, num_query_vecs, -1)

        # apply the pruning threshold
        if threshold:
            expanded_centroid_scores = expanded_centroid_scores.masked_fill(
                expanded_centroid_scores.amax(1, keepdim=True) < threshold, 0
            )

        # NOTE this is ColBERT scoring, but with the centroid scores standing in for the doc embeddings
        # expanded_centroid_scores: `num_docs x num_query_vecs x (num_centroids + 1)`
        # padded_codes: `num_docs x num_query_vecs x max_num_codes_per_doc`
        # approx_similarity: `num_docs x num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=expanded_centroid_scores, dim=-1, index=padded_codes)
        approx_scores = self.module.scoring_function.aggregate_similarity(
            approx_similarity, query_scoring_mask=None, doc_scoring_mask=mask[:, None]
        )

        filtered_doc_idcs = []
        lengths = []
        iterator = zip(doc_idcs.packed_tensor.split(doc_idcs.lengths), approx_scores.split(doc_idcs.lengths))
        for per_query_doc_idcs, per_query_scores in iterator:
            if per_query_scores.shape[0] <= k:
                filtered_doc_idcs.append(per_query_doc_idcs)
            else:
                # keep the top-k candidates per query (`torch.topk` returns a (values, indices) tuple)
                filtered_doc_idcs.append(per_query_doc_idcs[torch.topk(per_query_scores, k, sorted=False).indices])
            lengths.append(filtered_doc_idcs[-1].shape[0])

        packed_filtered_doc_idcs = PackedTensor(torch.cat(filtered_doc_idcs), lengths)

        return packed_filtered_doc_idcs

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        query_embeddings = query_embeddings.to(self.device)
        centroid_scores, doc_idcs = self.candidate_retrieval(query_embeddings)
        # NOTE it is unclear why two filter steps are necessary (the first with a threshold, the second without)
        # filter step 1
        filtered_doc_idcs = self.filter_candidates(
            centroid_scores, doc_idcs, self.search_config.centroid_score_threshold, self.search_config.candidate_k
        )
        # filter step 2
        filtered_doc_idcs = self.filter_candidates(
            centroid_scores, filtered_doc_idcs, None, self.search_config.candidate_k // 4
        )

        # gather/decompress the document embeddings
        doc_embedding_codes = self.packed_codes.lookup(filtered_doc_idcs.packed_tensor, 1)
        doc_embedding_residuals = self.packed_residuals.lookup(filtered_doc_idcs.packed_tensor, 1)
        doc_embeddings = self.residual_codec.decompress(doc_embedding_codes, doc_embedding_residuals)
        padded_doc_embeddings = doc_embeddings.to_padded_tensor()
        doc_scoring_mask = padded_doc_embeddings[..., 0] != 0

        # compute scores
        num_docs = filtered_doc_idcs.lengths
        doc_scores = self.module.scoring_function.forward(
            query_embeddings,
            BiEncoderEmbedding(padded_doc_embeddings, doc_scoring_mask, None),
            num_docs,
        )
        return doc_scores, filtered_doc_idcs.packed_tensor, num_docs
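
A minimal usage sketch (the index path and model name are hypothetical, and the exact BiEncoderModule constructor signature is an assumption; the searcher expects an index built by the PLAID indexer with the same bi-encoder). PlaidSearchConfig is defined below; with k=10 it resolves to candidate_k=256, n_cells=1, and centroid_score_threshold=0.5. The PLAID-specific logic lives in _search above, which is invoked through the search interface inherited from Searcher.

    from lightning_ir import BiEncoderModule
    from lightning_ir.retrieve.plaid.plaid_searcher import PlaidSearchConfig, PlaidSearcher

    module = BiEncoderModule(model_name_or_path="colbert-ir/colbertv2.0")  # assumed: the model the index was built with
    config = PlaidSearchConfig(k=10)
    searcher = PlaidSearcher("path/to/plaid_index", config, module)  # hypothetical index directory
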
class PlaidSearchConfig(SearchConfig):

    search_class: Type[Searcher] = PlaidSearcher
    def __init__(
        self,
        k: int,
        candidate_k: int | None = None,
        n_cells: int | None = None,
        centroid_score_threshold: float | None = None,
    ) -> None:
        # default values follow ColBERT:
        # https://github.com/stanford-futuredata/ColBERT/blob/7067ef598b5011edaa1f4a731a2c269dbac864e4/colbert/searcher.py#L106
        super().__init__(k)
        if candidate_k is None:
            if k <= 10:
                candidate_k = 256
            elif k <= 100:
                candidate_k = 1_024
            else:
                candidate_k = max(k * 4, 4_096)
        self.candidate_k = candidate_k
        if n_cells is None:
            if k <= 10:
                n_cells = 1
            elif k <= 100:
                n_cells = 2
            else:
                n_cells = 4
        self.n_cells = n_cells
        if centroid_score_threshold is None:
            if k <= 10:
                centroid_score_threshold = 0.5
            elif k <= 100:
                centroid_score_threshold = 0.45
            else:
                centroid_score_threshold = 0.4
        self.centroid_score_threshold = centroid_score_threshold
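
For illustration (a quick check against the branches above, not part of the module), the k-dependent defaults resolve as follows:

    config = PlaidSearchConfig(k=50)
    assert config.candidate_k == 1_024
    assert config.n_cells == 2
    assert config.centroid_score_threshold == 0.45

    config = PlaidSearchConfig(k=1_000)
    assert config.candidate_k == 4_096  # max(4 * 1_000, 4_096)
    assert config.n_cells == 4
    assert config.centroid_score_threshold == 0.4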