1from __future__ import annotations
2
3import pathlib
4from itertools import product
5from pathlib import Path
6from typing import TYPE_CHECKING, Tuple
7
8import numpy as np
9import torch
10from torch.utils.cpp_extension import load
11
12from .packed_tensor import PackedTensor
13
14if TYPE_CHECKING:
15 from .plaid_indexer import PlaidIndexConfig
16
17
[docs]
class ResidualCodec:
    """Compress embeddings into centroid codes plus bit-packed, quantized residuals.

    An embedding is assigned to the centroid with the highest dot product. The
    difference to that centroid (the residual) is quantized per dimension into
    one of ``len(bucket_weights)`` buckets using ``bucket_cutoffs``, and the
    ``n_bits``-wide bucket indices are packed into bytes. Decompression looks
    the bucket weights back up, adds them to the centroid and L2-normalizes.
    """

    def __init__(
        self,
        index_config: "PlaidIndexConfig",
        centroids: torch.Tensor,
        bucket_cutoffs: torch.Tensor,
        bucket_weights: torch.Tensor,
        verbose: bool = False,
    ) -> None:
        """Store the codec parameters and precompute the decompression lookup tables.

        Args:
            index_config: Configuration object; only ``n_bits`` is read here.
            centroids: Centroid matrix; ``shape[0]`` is the number of centroids
                and ``shape[-1]`` the embedding dimension.
            bucket_cutoffs: Boundaries passed to ``torch.bucketize`` when
                quantizing residuals.
            bucket_weights: Value substituted for each bucket index when
                decompressing.
            verbose: Stored on the instance as ``self.verbose``.

        Raises:
            ValueError: If ``n_bits`` is not 1, 2, 4 or 8. The byte-level lookup
                tables require the bucket indices to tile a byte exactly.
        """
        n_bits = index_config.n_bits
        # Fail early and readably: other values previously surfaced as an opaque
        # "negative shift count" ValueError while building the reverse bit map.
        if n_bits not in (1, 2, 4, 8):
            raise ValueError(f"n_bits must be one of 1, 2, 4 or 8 to evenly tile a byte, got {n_bits!r}")

        self.index_config = index_config
        self.verbose = verbose

        self.centroids = centroids
        self.bucket_cutoffs = bucket_cutoffs
        self.bucket_weights = bucket_weights

        # Bit positions [0, ..., n_bits - 1] used to split a bucket index into bits.
        self.arange_bits = torch.arange(0, n_bits, dtype=torch.uint8)
        self.reversed_bit_map = self._compute_reverse_bit_map()
        # Row b of the table lists the bucket indices stored in (bit-reversed) byte b.
        keys_per_byte = 8 // n_bits
        self.decompression_lookup_table = torch.tensor(
            list(product(range(len(self.bucket_weights)), repeat=keys_per_byte))
        ).to(torch.uint8)

        # Number of bytes one compressed residual occupies.
        self.residual_dim = max(1, centroids.shape[-1] // 8 * n_bits)

        self._packbits_cpp = None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(dim={self.dim}, num_centroids={self.num_centroids})"

    def __str__(self) -> str:
        return self.__repr__()

    @property
    def dim(self) -> int:
        """Embedding dimension (size of the last centroid axis)."""
        return self.centroids.shape[-1]

    @property
    def num_centroids(self) -> int:
        """Number of centroids (size of the first centroid axis)."""
        return self.centroids.shape[0]

    @classmethod
    def train(
        cls, index_config: "PlaidIndexConfig", train_embeddings: torch.Tensor, verbose: bool = False
    ) -> "ResidualCodec":
        """Fit centroids and residual buckets on ``train_embeddings``.

        The embeddings are shuffled; the last ``min(5%, 2**15)`` rows are held
        out to estimate the bucket cutoffs and weights, and k-means is trained
        on the remaining rows.
        """
        train_embeddings = train_embeddings[torch.randperm(train_embeddings.shape[0])]
        num_hold_out_embeddings = int(min(0.05 * train_embeddings.shape[0], 2**15))
        train_embeddings, holdout_embeddings = train_embeddings.split(
            [train_embeddings.shape[0] - num_hold_out_embeddings, num_hold_out_embeddings]
        )

        centroids = cls._train_kmeans(train_embeddings, index_config, verbose)
        bucket_cutoffs, bucket_weights = cls._compute_buckets(centroids, holdout_embeddings, index_config)

        return cls(index_config, centroids, bucket_cutoffs, bucket_weights, verbose)

    @staticmethod
    def _train_kmeans(
        embeddings: torch.Tensor, index_config: "PlaidIndexConfig", verbose: bool = False
    ) -> torch.Tensor:
        """Run faiss k-means on ``embeddings`` and return the L2-normalized centroids."""
        # Imported lazily so that faiss is only required when training a codec.
        import faiss

        kmeans = faiss.Kmeans(
            embeddings.shape[-1],
            index_config.num_centroids,
            niter=index_config.k_means_iters,
            gpu=torch.cuda.is_available(),
            verbose=verbose,
            seed=index_config.seed,
        )
        kmeans.train(embeddings.numpy())
        # TODO why normalize?
        return torch.nn.functional.normalize(torch.from_numpy(kmeans.centroids), dim=-1)

    def _packbits(self, residuals: torch.Tensor) -> torch.Tensor:
        """Pack a tensor of 0/1 values into bytes with ``numpy.packbits``.

        Raises:
            NotImplementedError: If ``residuals`` lives on a CUDA device.
        """
        # ``is_cuda`` also catches indexed devices such as ``cuda:0``; comparing
        # against ``torch.device("cuda")`` does not, because the indices differ.
        if residuals.is_cuda:
            raise NotImplementedError("CUDA not supported for packbits")
        flat_bits = np.asarray(residuals.contiguous().flatten())
        return torch.from_numpy(np.packbits(flat_bits))

    @staticmethod
    def _compute_buckets(
        centroids: torch.Tensor, holdout_embeddings: torch.Tensor, index_config: "PlaidIndexConfig"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Derive bucket cutoffs and weights from quantiles of the held-out residuals.

        With ``2**n_bits`` buckets, the cutoffs are the residual quantiles at
        ``i / 2**n_bits`` for ``i >= 1`` and the weights are the quantiles at
        ``(i + 0.5) / 2**n_bits`` for ``i >= 0``.

        Returns:
            ``(bucket_cutoffs, bucket_weights)``.
        """
        holdout_codes = ResidualCodec._compress_into_codes(centroids, holdout_embeddings)
        holdout_residual = holdout_embeddings - centroids[holdout_codes]

        num_options = 2**index_config.n_bits
        quantiles = torch.arange(0, num_options, device=holdout_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles = quantiles[1:]
        bucket_weights_quantiles = quantiles + (0.5 / num_options)

        residual_values = holdout_residual.float()
        bucket_cutoffs = residual_values.quantile(bucket_cutoffs_quantiles)
        bucket_weights = residual_values.quantile(bucket_weights_quantiles)
        return bucket_cutoffs, bucket_weights

    def _compute_reverse_bit_map(self) -> torch.Tensor:
        """Build a 256-entry table that bit-reverses every ``n_bits`` group of a byte.

        We reverse the residual bits because ``arange_bits`` as currently
        constructed produces results with the reverse of the expected
        endianness. The order of the groups within the byte is preserved.
        """
        n_bits = self.index_config.n_bits
        mask = (1 << n_bits) - 1
        reversed_bit_map = []
        for byte in range(256):
            reversed_byte = 0
            for high in range(8, 0, -n_bits):
                # Extract the n-bit group whose top bit sits just below position ``high``.
                group = (byte >> (high - n_bits)) & mask

                # Reverse the endianness of the group (e.g. 10 -> 01).
                flipped = 0
                for k in range(n_bits):
                    flipped |= ((group >> (n_bits - k - 1)) & 1) << k

                # Append the group to the output byte, making room for the next one.
                reversed_byte |= flipped
                if high > n_bits:
                    reversed_byte <<= n_bits
            reversed_bit_map.append(reversed_byte)
        return torch.tensor(reversed_bit_map).to(torch.uint8)

    @classmethod
    def try_load_torch_extensions(cls, use_gpu: bool) -> None:
        """Compile and attach the ``decompress_residuals`` extension once, if ``use_gpu`` is set."""
        if hasattr(cls, "loaded_extensions") or not use_gpu:
            return

        csrc_dir = pathlib.Path(__file__).parent.resolve() / "csrc"
        decompress_residuals_cpp = load(
            name="decompress_residuals_cpp",
            sources=[
                str(csrc_dir / "decompress_residuals.cpp"),
                str(csrc_dir / "decompress_residuals.cu"),
            ],
        )
        cls.decompress_residuals = decompress_residuals_cpp.decompress_residuals_cpp

        cls.loaded_extensions = True

    @classmethod
    def from_pretrained(cls, index_config: "PlaidIndexConfig", index_dir: Path) -> "ResidualCodec":
        """Load a codec from ``centroids.pt`` and ``buckets.pt`` inside ``index_dir`` (onto the CPU)."""
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        centroids = torch.load(centroids_path, map_location="cpu")
        bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location="cpu")

        return cls(
            index_config=index_config,
            centroids=centroids,
            bucket_cutoffs=bucket_cutoffs,
            bucket_weights=bucket_weights,
        )

    def save(self, index_dir: Path) -> None:
        """Write ``centroids.pt`` (half precision) and ``buckets.pt`` into ``index_dir``.

        The directory is created if it does not exist.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        centroids_path = index_dir / "centroids.pt"
        buckets_path = index_dir / "buckets.pt"

        torch.save(self.centroids.half(), centroids_path)
        torch.save((self.bucket_cutoffs, self.bucket_weights), buckets_path)

    @staticmethod
    def _compress_into_codes(centroids: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Return, for every embedding, the index of the centroid with the highest dot product."""
        # Process in batches so the (num_centroids x batch) score matrix holds
        # at most 2**29 entries.
        batch_size = 2**29 // centroids.shape[0]
        codes = [(centroids @ batch.transpose(-1, -2)).argmax(dim=0) for batch in embeddings.split(batch_size)]
        return torch.cat(codes)

    def compress_into_codes(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Assign each embedding to a centroid of this codec."""
        return self._compress_into_codes(self.centroids, embeddings)

    def compress(self, embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress ``embeddings`` into ``(codes, packed_residuals)``."""
        codes = self.compress_into_codes(embeddings)
        centroids = self.centroids[codes]
        residuals = self.binarize(embeddings - centroids)
        return codes, residuals

    def binarize(self, residuals: torch.Tensor) -> torch.Tensor:
        """Quantize ``residuals`` into bucket indices and pack their bits into bytes.

        Returns:
            A ``uint8`` tensor with one row per residual.
        """
        n_bits = self.index_config.n_bits
        buckets = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
        buckets_expanded = buckets.unsqueeze(-1).expand(*buckets.size(), n_bits)
        bucket_bits = buckets_expanded >> self.arange_bits  # divide by 2^bit for each bit position
        bucket_binary = bucket_bits & 1  # apply mod 2 to binarize

        residuals_packed = self._packbits(bucket_binary)
        return residuals_packed.reshape(residuals.size(0), max(1, self.dim // 8 * n_bits))

    def decompress(self, codes: "PackedTensor", compressed_residuals: "PackedTensor") -> "PackedTensor":
        """Reconstruct L2-normalized embeddings from centroid codes and packed residuals.

        Returns:
            A ``PackedTensor`` holding the embeddings, with the lengths of ``codes``.
        """
        centroids = self.centroids[codes.packed_tensor]
        packed = compressed_residuals.packed_tensor
        # Undo the per-group bit reversal introduced when packing.
        residuals = self.reversed_bit_map[packed.long().view(-1)].view_as(packed)
        # Expand every byte into the bucket indices it encodes ...
        residuals = self.decompression_lookup_table[residuals.long()]
        residuals = residuals.view(residuals.shape[0], -1)
        # ... and replace every bucket index with its weight.
        residuals = self.bucket_weights[residuals.long()]
        embeddings = centroids + residuals
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)

        return PackedTensor(embeddings, codes.lengths)