Source code for lightning_ir.retrieve.sparse.sparse_indexer
1import array
2from pathlib import Path
3
4import torch
5
6from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
7from ...data import IndexBatch
8from ..base import IndexConfig, Indexer
9
10
[docs]
class SparseIndexer(Indexer):
    """Indexer that accumulates document token embeddings as a sparse CSR matrix.

    Every scored document token embedding becomes one row of the matrix and
    every embedding dimension one column. The CSR components (row pointers,
    column indices and values) are collected incrementally in compact
    ``array.array`` buffers and assembled into a tensor in :meth:`save`.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: "SparseIndexConfig",
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        """Set up the empty CSR buffers.

        Args:
            index_dir: Directory the index is written to.
            index_config: Configuration of the index.
            bi_encoder_config: Configuration of the bi-encoder; its
                ``embedding_dim`` determines the number of index columns.
            verbose: Forwarded unchanged to the base indexer.
        """
        super().__init__(index_dir, index_config, bi_encoder_config, verbose)
        # Typecode "q" is a signed 64-bit integer on every platform, so the raw
        # buffer can be reinterpreted as torch.int64 in `save`. "L" (unsigned
        # long) is only 4 bytes wide on LLP64 platforms such as Windows.
        # CSR row pointers always start with a leading 0.
        self.crow_indices = array.array("q", [0])
        self.col_indices = array.array("q")
        self.values = array.array("f")

    @staticmethod
    def _as_tensor(buffer: array.array, dtype: torch.dtype) -> torch.Tensor:
        """Wrap ``buffer`` as a tensor of ``dtype``; empty buffers yield an empty tensor."""
        # torch.frombuffer refuses zero-length buffers, which occurs when no
        # non-zero value has been added to the index.
        if not buffer:
            return torch.empty(0, dtype=dtype)
        return torch.frombuffer(buffer, dtype=dtype)

    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Append the document embeddings of one batch to the index buffers.

        Args:
            index_batch: Batch providing the ids of the encoded documents.
            output: Bi-encoder output holding the document embeddings.

        Raises:
            ValueError: If ``output`` contains no document embeddings.
        """
        doc_embeddings = output.doc_embeddings
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")

        doc_lengths = doc_embeddings.scoring_mask.sum(dim=1)
        # Keep only the token embeddings selected by the scoring mask; the
        # result has one row per scored token.
        embeddings = doc_embeddings.embeddings[doc_embeddings.scoring_mask]
        num_embeddings = embeddings.shape[0]
        num_docs = len(index_batch.doc_ids)
        self.doc_ids.extend(index_batch.doc_ids)

        token_idcs, dim_idcs = torch.nonzero(embeddings, as_tuple=True)
        # `minlength` guarantees one row pointer per token even when trailing
        # tokens (or all tokens) have no non-zero entry; without it the row
        # pointers would fall out of sync with `num_embeddings`.
        nnz_per_token = torch.bincount(token_idcs, minlength=num_embeddings)
        # Offset by the running total so the pointers continue the previous batch.
        crow_indices = nnz_per_token.cumsum(0) + self.crow_indices[-1]
        values = embeddings[token_idcs, dim_idcs]
        self.crow_indices.extend(crow_indices.cpu().tolist())
        self.col_indices.extend(dim_idcs.cpu().tolist())
        self.values.extend(values.cpu().tolist())

        self.doc_lengths.extend(doc_lengths.cpu().tolist())
        self.num_embeddings += num_embeddings
        self.num_docs += num_docs

    def to_gpu(self) -> None:
        """Do nothing; this indexer has no device-specific state to move."""
        pass

    def to_cpu(self) -> None:
        """Do nothing; this indexer has no device-specific state to move."""
        pass

    def save(self) -> None:
        """Write the base index data and the sparse CSR index (``index.pt``) to ``index_dir``."""
        super().save()
        index = torch.sparse_csr_tensor(
            self._as_tensor(self.crow_indices, torch.int64),
            self._as_tensor(self.col_indices, torch.int64),
            self._as_tensor(self.values, torch.float32),
            torch.Size([self.num_embeddings, self.bi_encoder_config.embedding_dim]),
        )
        torch.save(index, self.index_dir / "index.pt")
61
62
[docs]
class SparseIndexConfig(IndexConfig):
    """Index configuration whose ``indexer_class`` is :class:`SparseIndexer`."""

    # Indexer implementation associated with this configuration.
    indexer_class = SparseIndexer