Source code for lightning_ir.retrieve.plaid.plaid_indexer
import warnings
from array import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer
from .residual_codec import ResidualCodec


class PlaidIndexer(Indexer):
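    """Indexer that compresses bi-encoder document embeddings into a PLAID index. A :class:`ResidualCodec`
    is trained on the first ``num_train_embeddings`` embeddings; all embeddings are then stored as
    centroid codes plus quantized residuals."""
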
    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
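        """Initialize a PlaidIndexer.

        :param index_dir: Directory to write the index to
        :param index_config: Configuration of the PLAID index
        :param bi_encoder_config: Configuration of the bi-encoder model
        :param verbose: Whether to log progress, defaults to False
        """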
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        try:
            import faiss  # noqa: F401  # only imported to verify that faiss is installed
        except ImportError:
            raise ImportError("faiss is required for PlaidIndexer")

        self.index_config: PlaidIndexConfig

        # buffer that collects embeddings until enough are available to train the residual codec
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.bi_encoder_config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        self.codes = array("l")
        self.residuals = array("B")
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
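        """Add a batch of documents to the index.

        :param index_batch: Batch of document ids
        :param output: Bi-encoder output containing the document embeddings
        """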
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        # the scoring mask marks the non-padding embeddings of each document
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            # compress embeddings into centroid codes and quantized residuals
            codes, residuals = self.residual_codec.compress(embeddings)
            self.codes.extend(codes.numpy(force=True))
            self.residuals.extend(residuals.view(-1).numpy(force=True))

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)
    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
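        """Divert embeddings into the training buffer and trigger codec training once the buffer is full.

        :param embeddings: Embeddings to process
        :return: Embeddings that were not consumed by the training buffer
        """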
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # collect training embeddings until num_train_embeddings is reached;
            # once the buffer is full, the leftover embeddings are returned for compression
            start = self.num_embeddings
            end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
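        """Train the residual codec once ``num_train_embeddings`` embeddings have been collected, then
        compress and store the buffered training embeddings. With ``force=True`` the codec is trained on
        a partially filled buffer (used when saving an index over a corpus that yields fewer embeddings
        than requested)."""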
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        # rows of the buffer that were never filled are still NaN; drop them before training
        if torch.isnan(self._train_embeddings).any():
            warnings.warn(
                "Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings."
            )
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]

        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

        self._train_embeddings = None

    def save(self) -> None:
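        """Train the residual codec if it has not been trained yet and write the codes, residuals, and
        codec to the index directory."""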
        if self.residual_codec is None:
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)


class PlaidIndexConfig(IndexConfig):
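    """Configuration for a PLAID index."""
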
    indexer_class = PlaidIndexer

    def __init__(
        self, num_centroids: int, num_train_embeddings: int, k_means_iters: int = 4, n_bits: int = 2, seed: int = 42
    ) -> None:
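        """Initialize a PlaidIndexConfig.

        :param num_centroids: Number of centroids for the residual codec
        :param num_train_embeddings: Number of embeddings used to train the residual codec
        :param k_means_iters: Number of k-means iterations when training the codec, defaults to 4
        :param n_bits: Number of bits per quantized residual dimension, defaults to 2
        :param seed: Random seed, defaults to 42
        """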
        super().__init__()
        self.num_centroids = num_centroids
        self.num_train_embeddings = num_train_embeddings
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
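

Usage sketch (not part of the module): within Lightning IR the indexer is normally driven by the framework's indexing pipeline, but conceptually it is fed batches of bi-encoder outputs and then saved. The configuration values below are illustrative, the `batches` iterable is a placeholder, and a default-constructible BiEncoderConfig is assumed.

# Hypothetical driver loop; `batches` is assumed to yield (IndexBatch, BiEncoderOutput) pairs.
index_config = PlaidIndexConfig(num_centroids=2**16, num_train_embeddings=1_000_000)
indexer = PlaidIndexer(Path("path/to/index"), index_config, BiEncoderConfig(), verbose=True)
for index_batch, output in batches:
    indexer.add(index_batch, output)
indexer.save()  # trains the codec if needed and writes codes.pt, residuals.pt, and the codec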