Source code for lightning_ir.retrieve.plaid.residual_codec

  1from __future__ import annotations
  2
  3import pathlib
  4from itertools import product
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Tuple
  7
  8import numpy as np
  9import torch
 10from torch.utils.cpp_extension import load
 11
 12from .packed_tensor import PackedTensor
 13
 14if TYPE_CHECKING:
 15    from .plaid_indexer import PlaidIndexConfig
 16
 17
class ResidualCodec:
    """Compress embeddings as a nearest-centroid id plus a bit-packed quantized residual.

    Every residual component is mapped to one of ``len(bucket_weights)`` buckets
    (via ``bucket_cutoffs``), stored with ``index_config.n_bits`` bits, and packed
    several components per byte. Decompression looks the bucket ids back up and
    replaces each with its ``bucket_weights`` entry.
    """

    # Bit widths that evenly tile one byte; the byte-level lookup tables built in
    # ``__init__`` are only well defined for these.
    _SUPPORTED_N_BITS = (1, 2, 4, 8)

    def __init__(
        self,
        index_config: "PlaidIndexConfig",
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Store the codec parameters and precompute the byte-level lookup tables.

        Args:
            index_config: Configuration object; only ``n_bits`` is read here.
            centroids: Centroid matrix; the first dimension indexes centroids and
                the last dimension is the embedding dimension.
            bucket_cutoffs: Boundaries handed to ``torch.bucketize`` when quantizing.
            bucket_weights: Reconstruction value for every bucket id.
            verbose: Stored on the instance; not otherwise used in this class.

        Raises:
            ValueError: If ``index_config.n_bits`` is not one of 1, 2, 4 or 8.
        """
        n_bits = index_config.n_bits
        # Fail early with a readable message: any other width cannot tile a byte
        # and would otherwise blow up inside the bit-map construction.
        if n_bits not in self._SUPPORTED_N_BITS:
            raise ValueError(f"n_bits must be one of {self._SUPPORTED_N_BITS}, got {n_bits!r}")

        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        # Bit positions 0..n_bits-1, used to split a bucket id into single bits.
        self.arange_bits = torch.arange(0, n_bits, dtype=torch.uint8)
        self.reversed_bit_map = self._compute_reverse_bit_map()

        # Row ``b`` of the table lists the bucket ids stored in byte value ``b``
        # (after the bit order has been fixed up by ``reversed_bit_map``).
        keys_per_byte = 8 // n_bits
        self.decompression_lookup_table = torch.tensor(
            list(product(range(len(self.bucket_weights)), repeat=keys_per_byte))
        ).to(torch.uint8)

        # Number of packed bytes per embedding.
        self.residual_dim = max(1, centroids.shape[-1] // 8 * n_bits)

        self._packbits_cpp = None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Size of the last dimension of ``centroids``."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Size of the first dimension of ``centroids``."""
        return self.centroids.shape[0]

    @classmethod
    def train(
        cls, index_config: "PlaidIndexConfig", train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Fit centroids and residual buckets on a sample of embeddings.

        The embeddings are shuffled along the first dimension; the last
        ``min(5 %, 2**15)`` rows are held out to estimate the bucket statistics
        and the remainder is used for k-means.

        Args:
            index_config: Configuration forwarded to k-means and bucket estimation.
            train_embeddings: Sample of embeddings, one per row.
            verbose: Forwarded to k-means and stored on the returned codec.

        Returns:
            A codec built from the fitted centroids, cutoffs and weights.
        """
        train_embeddings = train_embeddings[torch.randperm(train_embeddings.shape[0])]
        num_hold_out_embeddings = int(min(0.05 * train_embeddings.shape[0], 2**15))
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [train_embeddings.shape[0] - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)

    @staticmethod
    def _train_kmeans(
        embeddings: torch.Tensor, index_config: "PlaidIndexConfig", verbose: bool = False
    ) -> torch.Tensor:
        """Run faiss k-means on ``embeddings`` and return L2-normalized centroids.

        ``embeddings`` is converted with ``.numpy()``, so it must be a tensor that
        supports that conversion (e.g. it lives on the CPU).
        """
        # Imported lazily so that faiss is only required when training.
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        # TODO why normalize?
        kmeans.train(embeddings.numpy())
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        """Pack a tensor of 0/1 values into bytes with ``numpy.packbits``.

        Args:
            residuals: Tensor of single bits; it is flattened before packing.

        Returns:
            A flat ``uint8`` tensor holding eight input bits per byte.

        Raises:
            NotImplementedError: If ``residuals`` lives on a CUDA device.
        """
        # Compare the device *type*: a CUDA tensor reports an indexed device such
        # as ``cuda:0``, which is not equal to the index-less ``torch.device("cuda")``.
        if residuals.device.type == "cuda":
            raise NotImplementedError("CUDA not supported for packbits")
        return torch.from_numpy(np.packbits(np.asarray(residuals.contiguous().flatten())))

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: "PlaidIndexConfig"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate bucket boundaries and reconstruction values from held-out residuals.

        With ``k = 2**n_bits`` buckets, the cutoffs are the ``i/k`` quantiles
        (``i = 1..k-1``) and the weights are the ``(i + 0.5)/k`` quantiles
        (``i = 0..k-1``) of all residual components.

        Returns:
            ``(bucket_cutoffs, bucket_weights)``.
        """
        holdout_embeddings_codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_residual = holdout_embeddings - centroids[holdout_embeddings_codes]

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=holdout_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles = quantiles[1:]
        bucket_weights_quantiles = quantiles + (0.5 / num_options)

        flat_residual = holdout_residual.float()
        bucket_cutoffs = flat_residual.quantile(bucket_cutoffs_quantiles)
        bucket_weights = flat_residual.quantile(bucket_weights_quantiles)
        return bucket_cutoffs, bucket_weights

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        """Build a 256-entry table that bit-reverses every ``n_bits`` group of a byte.

        ``binarize`` emits the bits of each bucket id least-significant first,
        while ``numpy.packbits`` fills bytes most-significant first, so each group
        ends up stored reversed. Indexing this table with a packed byte restores
        the natural bit order inside every group (e.g. ``10 -> 01``).
        """
        n_bits = self.index_config.n_bits
        mask = (1 << n_bits) - 1

        reversed_bit_map = []
        for byte in range(256):
            fixed = 0
            # Walk the groups from the most significant end of the byte.
            for shift in range(8 - n_bits, -1, -n_bits):
                group = (byte >> shift) & mask
                flipped = 0
                for _ in range(n_bits):
                    flipped = (flipped << 1) | (group & 1)
                    group >>= 1
                fixed = (fixed << n_bits) | flipped
            reversed_bit_map.append(fixed)
        return torch.tensor(reversed_bit_map).to(torch.uint8)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu: bool) -> None:
        """Compile and attach the ``decompress_residuals`` extension once per class.

        Does nothing when ``use_gpu`` is falsy or when the extension has already
        been loaded for this class.
        """
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        csrc_dir = pathlib.Path(__file__).parent.resolve() / "csrc"
        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(csrc_dir / "decompress_residuals.cpp"),
                str(csrc_dir / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    @classmethod
    def from_pretrained(cls, index_config: "PlaidIndexConfig", index_dir: Path) -> "ResidualCodec":
        """Load a codec from ``centroids.pt`` and ``buckets.pt`` inside ``index_dir``.

        Both files are loaded onto the CPU.
        """
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        centroids = torch.load(centroids_path, map_location="cpu")
        bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location="cpu")

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_dir: Path) -> None:
        """Write ``centroids.pt`` (in half precision) and ``buckets.pt`` to ``index_dir``.

        The directory is created if it does not exist.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)

    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Return, for every embedding row, the index of the centroid with the largest dot product.

        The embeddings are processed in batches so that the intermediate score
        matrix stays at roughly ``2**29`` entries.
        """
        # Clamp to 1 so an extremely large centroid table cannot yield a batch size of 0.
        batch_size = max(1, 2**29 // centroids.shape[0])
        codes = [
            (centroids @ batch.transpose(-1, -2)).argmax(dim=0) for batch in embeddings.split(batch_size)
        ]
        return torch.cat(codes)

    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Assign every embedding to its best-scoring centroid of this codec."""
        return self._compress_into_codes(self.centroids, embeddings)

    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode embeddings as centroid ids plus bit-packed quantized residuals.

        Returns:
            ``(codes, residuals)`` where ``codes`` holds one centroid index per
            embedding and ``residuals`` is the output of ``binarize``.
        """
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals

    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Quantize residuals into buckets and pack the bucket ids into bytes.

        Args:
            residuals: One residual vector per row. The reshape below assumes the
                packed bits of a row fill exactly ``residual_dim`` bytes.

        Returns:
            ``uint8`` tensor with ``residuals.size(0)`` rows and ``residual_dim`` columns.
        """
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), self.index_config.n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize

        residuals_packed = self._packbits(bucket_binary)
        return residuals_packed.reshape(residuals.size(0), self.residual_dim)

    def decompress(self, codes: "PackedTensor", compressed_residuals: "PackedTensor") -> "PackedTensor":
        """Reconstruct L2-normalized embeddings from centroid ids and packed residuals.

        Args:
            codes: Packed centroid indices; ``packed_tensor`` and ``lengths`` are read.
            compressed_residuals: Packed residual bytes as produced by ``binarize``.

        Returns:
            A ``PackedTensor`` of reconstructed embeddings carrying ``codes.lengths``.
        """
        centroids = self.centroids[codes.packed_tensor]

        packed = compressed_residuals.packed_tensor
        # Undo the per-group bit reversal introduced when the bits were packed.
        residuals = self.reversed_bit_map[packed.long().view(-1)].view_as(packed)
        # Expand every byte into the bucket ids it stores, then flatten per row.
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        residuals = self.bucket_weights[residuals.long()]

        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)

        return PackedTensor(embeddings, codes.lengths)