Source code for lightning_ir.retrieve.plaid.plaid_indexer

import warnings
from array import array
from pathlib import Path

import torch

from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
from ...data import IndexBatch
from ..base import IndexConfig, Indexer
from .residual_codec import ResidualCodec


class PlaidIndexer(Indexer):

    def __init__(
        self,
        index_dir: Path,
        index_config: "PlaidIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        # fail early if faiss is not installed
        try:
            import faiss
        except ImportError:
            raise ImportError("faiss is required for PlaidIndexer")

        self.index_config: PlaidIndexConfig

        # buffer that collects embeddings for training the residual codec; initialized with NaN
        self._train_embeddings: torch.Tensor | None = torch.full(
            (self.index_config.num_train_embeddings, self.bi_encoder_config.embedding_dim),
            torch.nan,
            dtype=torch.float32,
        )
        self.residual_codec: ResidualCodec | None = None
        self.codes = array("l")
        self.residuals = array("B")

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        doc_ids = index_batch.doc_ids
        embeddings = self.process_embeddings(embeddings)

        if embeddings.shape[0]:
            if self.residual_codec is None:
                raise ValueError("Residual codec not trained")
            codes, residuals = self.residual_codec.compress(embeddings)
            self.codes.extend(codes.numpy(force=True))
            self.residuals.extend(residuals.view(-1).numpy(force=True))

        self.num_embeddings += embeddings.shape[0]
        self.num_docs += len(doc_ids)

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.doc_ids.extend(doc_ids)

    def process_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = self._grab_train_embeddings(embeddings)
        self._train()
        return embeddings

    def _grab_train_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        if self._train_embeddings is not None:
            # buffer training embeddings until num_train_embeddings is reached;
            # any embeddings beyond that are returned for regular indexing
            start = self.num_embeddings
            end = min(self.index_config.num_train_embeddings, start + embeddings.shape[0])
            length = end - start
            self._train_embeddings[start:end] = embeddings[:length]
            self.num_embeddings += length
            embeddings = embeddings[length:]
        return embeddings

    def _train(self, force: bool = False) -> None:
        if self._train_embeddings is None:
            return
        if not force and self.num_embeddings < self.index_config.num_train_embeddings:
            return

        if torch.isnan(self._train_embeddings).any():
            warnings.warn("Corpus contains fewer tokens/documents than num_train_embeddings. Removing NaN embeddings.")
            self._train_embeddings = self._train_embeddings[~torch.isnan(self._train_embeddings).any(dim=1)]

        self.residual_codec = ResidualCodec.train(self.index_config, self._train_embeddings, self.verbose)
        codes, residuals = self.residual_codec.compress(self._train_embeddings)
        self.codes.extend(codes.numpy(force=True))
        self.residuals.extend(residuals.view(-1).numpy(force=True))

        self._train_embeddings = None

    def save(self) -> None:
        if self.residual_codec is None:
            self._train(force=True)
        if self.residual_codec is None:
            raise ValueError("No residual codec to save")
        super().save()

        codes = torch.frombuffer(self.codes, dtype=torch.long)
        residuals = torch.frombuffer(self.residuals, dtype=torch.uint8)
        torch.save(codes, self.index_dir / "codes.pt")
        torch.save(residuals, self.index_dir / "residuals.pt")
        self.residual_codec.save(self.index_dir)


class PlaidIndexConfig(IndexConfig):
    indexer_class = PlaidIndexer

    def __init__(
        self, num_centroids: int, num_train_embeddings: int, k_means_iters: int = 4, n_bits: int = 2, seed: int = 42
    ) -> None:
        super().__init__()
        self.num_centroids = num_centroids
        self.num_train_embeddings = num_train_embeddings
        self.k_means_iters = k_means_iters
        self.n_bits = n_bits
        self.seed = seed
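
As a rough, hypothetical usage sketch (not part of the module above), the snippet below illustrates how PlaidIndexConfig and PlaidIndexer might be wired together. The bi_encoder_config instance, the encoded_batches iterable, and the concrete config values are assumptions for illustration; the constructor signatures and the add/save calls follow the source above.

from pathlib import Path

config = PlaidIndexConfig(num_centroids=2**16, num_train_embeddings=2**20)  # illustrative values
indexer = PlaidIndexer(Path("indexes/plaid"), config, bi_encoder_config, verbose=True)

# encoded_batches is assumed to yield (IndexBatch, BiEncoderOutput) pairs produced by a bi-encoder
for index_batch, output in encoded_batches:
    indexer.add(index_batch, output)

# trains the residual codec if it has not been trained yet, then writes
# codes.pt and residuals.pt to the index directory and saves the codec alongside them
indexer.save()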