1from __future__ import annotations
2
3from pathlib import Path
4from typing import TYPE_CHECKING, List, Tuple, Type
5
6import torch
7
8from ...bi_encoder.model import BiEncoderEmbedding
9from ..base.searcher import SearchConfig, Searcher
10from .packed_tensor import PackedTensor
11from .plaid_indexer import PlaidIndexConfig
12from .residual_codec import ResidualCodec
13
14if TYPE_CHECKING:
15 from ...bi_encoder import BiEncoderModule
16
17
class PlaidSearcher(Searcher):
    """Searcher for PLAID indexes.

    Loads the compressed (centroid code + residual) document embeddings from disk and
    builds two inverted-file mappings used for candidate generation and filtering:
    centroid code -> document indices and document index -> centroid codes.
    """

    def __init__(
        self, index_dir: Path | str, search_config: PlaidSearchConfig, module: BiEncoderModule, use_gpu: bool = False
    ) -> None:
        """Initialize the searcher from a PLAID index directory.

        :param index_dir: Directory containing ``codes.pt``, ``residuals.pt``, and the index config
        :param search_config: Search-time hyperparameters (``candidate_k``, ``n_cells``, threshold)
        :param module: Bi-encoder module providing the scoring function
        :param use_gpu: Whether to run search on GPU
        """
        super().__init__(index_dir, search_config, module, use_gpu)
        self.residual_codec = ResidualCodec.from_pretrained(
            PlaidIndexConfig.from_pretrained(self.index_dir), self.index_dir
        )

        # codes: one centroid id per embedding; residuals: packed quantized offsets per embedding
        self.codes = torch.load(self.index_dir / "codes.pt")
        self.residuals = torch.load(self.index_dir / "residuals.pt").view(self.codes.shape[0], -1)
        self.packed_codes = PackedTensor(self.codes, self.doc_lengths.tolist())
        self.packed_residuals = PackedTensor(self.residuals, self.doc_lengths.tolist())

        # code_idx to doc_idcs mapping (inverted file): for each centroid, the unique
        # documents that have at least one embedding assigned to it
        sorted_codes, embedding_idcs = self.codes.sort()
        num_embeddings_per_code = torch.bincount(sorted_codes, minlength=self.residual_codec.num_centroids).tolist()
        embedding_idx_to_doc_idx = torch.arange(self.num_docs).repeat_interleave(self.doc_lengths)
        full_doc_ivf = embedding_idx_to_doc_idx[embedding_idcs]
        doc_ivf_lengths = []
        unique_doc_idcs = []
        for code_doc_idcs in full_doc_ivf.split(num_embeddings_per_code):
            unique_doc_idcs.append(code_doc_idcs.unique())
            doc_ivf_lengths.append(unique_doc_idcs[-1].shape[0])
        self.code_to_doc_ivf = PackedTensor(torch.cat(unique_doc_idcs), doc_ivf_lengths)

        # doc_idx to code_idcs mapping: the inverse of the inverted file above
        sorted_doc_idcs, doc_idx_to_code_idx = self.code_to_doc_ivf.packed_tensor.sort()
        code_idcs = torch.arange(self.residual_codec.num_centroids).repeat_interleave(
            torch.tensor(self.code_to_doc_ivf.lengths)
        )[doc_idx_to_code_idx]
        num_codes_per_doc = torch.bincount(sorted_doc_idcs, minlength=self.num_docs)
        self.doc_to_code_ivf = PackedTensor(code_idcs, num_codes_per_doc.tolist())

        # narrow the type of the inherited attribute for type checkers
        self.search_config: PlaidSearchConfig

    @property
    def num_embeddings(self) -> int:
        """Total number of document embeddings in the index."""
        return int(self.cumulative_doc_lengths[-1].item())

    def candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, PackedTensor]:
        """Retrieve candidate documents via the nearest centroid cells.

        :param query_embeddings: Query embeddings with a scoring mask
        :return: Tuple of query--centroid similarity scores
            (`num_queries x query_length x num_centroids`) and the packed candidate
            document indices per query
        """
        # grab top `n_cells` neighbor cells for all query embeddings
        # `num_queries x query_length x num_centroids`
        scores = (
            query_embeddings.embeddings.to(self.residual_codec.centroids)
            @ self.residual_codec.centroids.transpose(0, 1)[None]
        )
        scores = scores.masked_fill(~query_embeddings.scoring_mask[..., None], 0)
        _, codes = torch.topk(scores, self.search_config.n_cells, dim=-1, sorted=False)
        packed_codes = codes[query_embeddings.scoring_mask].view(-1)
        code_lengths = (query_embeddings.scoring_mask.sum(-1) * self.search_config.n_cells).tolist()

        # grab document idcs for all cells
        packed_doc_idcs = self.code_to_doc_ivf.lookup(packed_codes, code_lengths, unique=True)
        return scores, packed_doc_idcs

    def filter_candidates(
        self, centroid_scores: torch.Tensor, doc_idcs: PackedTensor, threshold: float | None, k: int
    ) -> PackedTensor:
        """Keep the top-``k`` candidate documents per query based on approximate scores.

        Documents are scored with ColBERT-style late interaction, but using each
        embedding's centroid score as a cheap stand-in for the decompressed embedding.

        :param centroid_scores: `num_queries x num_query_vecs x num_centroids` similarities
        :param doc_idcs: Packed candidate document indices per query
        :param threshold: Optional centroid-score pruning threshold (``None``/``0`` disables pruning)
        :param k: Number of documents to keep per query
        :return: Packed filtered document indices per query
        """
        num_query_vecs = centroid_scores.shape[1]
        num_centroids = centroid_scores.shape[-1]

        # repeat query centroid scores for each document
        # `num_docs x num_query_vecs x num_centroids + 1`
        # NOTE we pad values such that the codes with -1 padding index 0 values
        expanded_centroid_scores = torch.nn.functional.pad(
            centroid_scores.repeat_interleave(torch.tensor(doc_idcs.lengths), dim=0), (0, 1)
        )

        # grab codes for each document
        code_idcs = self.doc_to_code_ivf.lookup(doc_idcs.packed_tensor, 1)
        # `num_docs x max_num_codes_per_doc`; pad with `num_centroids`, which indexes
        # the zero column added above
        padded_codes = code_idcs.to_padded_tensor(pad_value=num_centroids)
        mask = padded_codes != num_centroids
        # `num_docs x num_query_vecs x max_num_codes_per_doc`
        padded_codes = padded_codes[:, None].expand(-1, num_query_vecs, -1)

        # apply pruning threshold: zero out centroid scores whose maximum over query
        # vectors falls below the threshold (None is falsy, so it also disables pruning)
        if threshold:
            expanded_centroid_scores = expanded_centroid_scores.masked_fill(
                expanded_centroid_scores.amax(1, keepdim=True) < threshold, 0
            )

        # NOTE this is colbert scoring, but instead of using the doc embeddings we use the centroid scores
        # expanded_centroid_scores: `num_docs x num_query_vecs x num_centroids + 1`
        # padded_codes: `num_docs x num_query_vecs x max_num_codes_per_doc`
        # approx_similarity: `num_docs x num_query_vecs x max_num_codes_per_doc`
        approx_similarity = torch.gather(input=expanded_centroid_scores, dim=-1, index=padded_codes)
        approx_scores = self.module.scoring_function.aggregate_similarity(
            approx_similarity, query_scoring_mask=None, doc_scoring_mask=mask[:, None]
        )

        filtered_doc_idcs = []
        lengths = []
        iterator = zip(doc_idcs.packed_tensor.split(doc_idcs.lengths), approx_scores.split(doc_idcs.lengths))
        for query_doc_idcs, query_doc_scores in iterator:
            if query_doc_scores.shape[0] <= k:
                filtered_doc_idcs.append(query_doc_idcs)
            else:
                # FIX: torch.topk returns a (values, indices) named tuple; indexing a
                # 1-D tensor with the whole tuple raises IndexError — use .indices
                filtered_doc_idcs.append(query_doc_idcs[torch.topk(query_doc_scores, k, sorted=False).indices])
            lengths.append(filtered_doc_idcs[-1].shape[0])

        return PackedTensor(torch.cat(filtered_doc_idcs), lengths)

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Run the full PLAID search pipeline for a batch of queries.

        :param query_embeddings: Query embeddings with a scoring mask
        :return: Tuple of per-document scores, packed document indices, and the number
            of scored documents per query
        """
        query_embeddings = query_embeddings.to(self.device)
        centroid_scores, doc_idcs = self.candidate_retrieval(query_embeddings)
        # NOTE two filter steps mirror the original ColBERT/PLAID pipeline: a coarse
        # thresholded pass followed by a finer unthresholded pass on fewer candidates
        # filter step 1
        filtered_doc_idcs = self.filter_candidates(
            centroid_scores, doc_idcs, self.search_config.centroid_score_threshold, self.search_config.candidate_k
        )
        # filter step 2
        filtered_doc_idcs = self.filter_candidates(
            centroid_scores, filtered_doc_idcs, None, self.search_config.candidate_k // 4
        )

        # gather/decompress document embeddings
        doc_embedding_codes = self.packed_codes.lookup(filtered_doc_idcs.packed_tensor, 1)
        doc_embedding_residuals = self.packed_residuals.lookup(filtered_doc_idcs.packed_tensor, 1)
        doc_embeddings = self.residual_codec.decompress(doc_embedding_codes, doc_embedding_residuals)
        padded_doc_embeddings = doc_embeddings.to_padded_tensor()
        doc_scoring_mask = padded_doc_embeddings[..., 0] != 0

        # compute scores with the exact (decompressed) document embeddings
        num_docs = filtered_doc_idcs.lengths
        doc_scores = self.module.scoring_function.forward(
            query_embeddings,
            BiEncoderEmbedding(padded_doc_embeddings, doc_scoring_mask, None),
            num_docs,
        )
        return doc_scores, filtered_doc_idcs.packed_tensor, num_docs
155
156
class PlaidSearchConfig(SearchConfig):
    """Configuration for searching PLAID indexes."""

    search_class: Type[Searcher] = PlaidSearcher

    def __init__(
        self,
        k: int,
        candidate_k: int | None = None,
        n_cells: int | None = None,
        centroid_score_threshold: float | None = None,
    ) -> None:
        """Initialize the search configuration.

        Unset parameters fall back to heuristics based on ``k``, following the
        original ColBERT implementation:
        https://github.com/stanford-futuredata/ColBERT/blob/7067ef598b5011edaa1f4a731a2c269dbac864e4/colbert/searcher.py#L106

        :param k: Number of documents to return per query
        :param candidate_k: Number of candidate documents to retrieve before re-ranking
        :param n_cells: Number of centroid cells probed per query embedding
        :param centroid_score_threshold: Pruning threshold for centroid scores
        """
        super().__init__(k)
        if candidate_k is None:
            candidate_k = 256 if k <= 10 else (1_024 if k <= 100 else max(k * 4, 4_096))
        if n_cells is None:
            n_cells = 1 if k <= 10 else (2 if k <= 100 else 4)
        if centroid_score_threshold is None:
            centroid_score_threshold = 0.5 if k <= 10 else (0.45 if k <= 100 else 0.4)
        self.candidate_k = candidate_k
        self.n_cells = n_cells
        self.centroid_score_threshold = centroid_score_threshold