from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal, Tuple

import torch

from ..base import SearchConfig, Searcher
from .sparse_indexer import SparseIndexConfig

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule


class SparseIndex:
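    """Sparse in-memory index holding document token embeddings as a torch
    sparse tensor, scored against query token embeddings with a dot-product
    or cosine similarity."""
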
    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        self.index = torch.load(index_dir / "index.pt")
        self.config = SparseIndexConfig.from_pretrained(index_dir)
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError(f"Unknown similarity function: {similarity_function}")
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        embeddings = embeddings.to(self.device)
        similarity = self.similarity_function(embeddings, self.index).to_dense()
        return similarity

    @property
    def num_embeddings(self) -> int:
        return self.index.shape[0]
    def cosine_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # scale the dot product by the norms of both embedding matrices;
        # broadcasting yields a (num_query_tokens, num_doc_tokens) matrix
        dot_product = self.dot_similarity(x, y)
        return dot_product / (torch.norm(x, dim=-1)[:, None] * torch.norm(y, dim=-1)[None])

    def dot_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # (doc_tokens, dim) @ (dim, query_tokens), transposed to (query_tokens, doc_tokens)
        return torch.matmul(y, x.T).T

    def to_gpu(self) -> None:
        self.index = self.index.to(self.device)

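
# A minimal sketch (illustrative, not part of the original module) of what
# SparseIndex.dot_similarity computes for a tiny sparse index:
#
#     import torch
#
#     index = torch.tensor([[1.0, 0.0], [0.0, 2.0]]).to_sparse()  # two doc token embeddings
#     query = torch.tensor([[1.0, 1.0]])                          # one query token embedding
#     torch.matmul(index, query.T).T                              # -> tensor([[1., 2.]])
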
class SparseSearcher(Searcher):
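    """Searcher that scores query token embeddings against a :class:`SparseIndex`
    and aggregates token-level similarities into one score per document."""
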
    def __init__(
        self,
        index_dir: Path,
        search_config: SparseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        self.search_config: SparseSearchConfig
        self.index = SparseIndex(index_dir, module.config.similarity_function, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        # map each document token embedding to the index of its document
        self.doc_token_idcs = (
            torch.arange(self.doc_lengths.shape[0]).to(self.doc_lengths).repeat_interleave(self.doc_lengths)
        )
        self.use_gpu = use_gpu
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")
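
    # Illustration (not in the original source) of how doc_token_idcs is built:
    # with doc_lengths = tensor([2, 1, 3]),
    #     torch.arange(3).repeat_interleave(doc_lengths) -> tensor([0, 0, 1, 1, 2, 2, 2])
    # i.e., each document token embedding is labeled with its document's index.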

    @property
    def doc_is_single_vector(self) -> bool:
        # every document is a single vector iff the total number of document
        # token embeddings equals the number of documents
        return self.cumulative_doc_lengths[-1].item() == self.cumulative_doc_lengths.shape[0]

    def to_gpu(self) -> None:
        super().to_gpu()
        self.index.to_gpu()

    @property
    def num_embeddings(self) -> int:
        return self.index.num_embeddings

    def _search(self, query_embeddings: BiEncoderEmbedding) -> Tuple[torch.Tensor, None, None]:
        embeddings = query_embeddings.embeddings[query_embeddings.scoring_mask]
        query_lengths = query_embeddings.scoring_mask.sum(-1)
        scores = self.index.score(embeddings)

        # aggregate doc token scores: keep the maximum score per document
        if not self.doc_is_single_vector:
            scores = torch.scatter_reduce(
                torch.zeros(scores.shape[0], self.num_docs, device=scores.device),
                1,
                self.doc_token_idcs[None].long().expand_as(scores),
                scores,
                "amax",
            )

        # aggregate query token scores with the configured aggregation function
        query_is_single_vector = (query_lengths == 1).all()
        if not query_is_single_vector:
            query_token_idcs = torch.arange(query_lengths.shape[0]).to(query_lengths).repeat_interleave(query_lengths)
            scores = torch.scatter_reduce(
                torch.zeros(query_lengths.shape[0], self.num_docs, device=scores.device),
                0,
                query_token_idcs[:, None].expand_as(scores),
                scores,
                self.module.config.query_aggregation_function,
            )
        scores = scores.reshape(-1)
        return scores, None, None

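# A worked sketch of the "amax" doc-token aggregation above (illustrative, not
# part of the original module); the tensors are made-up toy values:
#
#     import torch
#
#     scores = torch.tensor([[0.2, 0.7, 0.5]])    # one query token, three doc token scores
#     doc_token_idcs = torch.tensor([[0, 0, 1]])  # doc 0 has two tokens, doc 1 has one
#     doc_scores = torch.scatter_reduce(
#         torch.zeros(1, 2), 1, doc_token_idcs, scores, "amax"
#     )
#     # doc_scores == tensor([[0.7, 0.5]])
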
class SparseSearchConfig(SearchConfig):
    search_class = SparseSearcher
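

# A minimal usage sketch (not part of the original module), assuming a trained
# bi-encoder checkpoint and an index previously written to index_dir by the
# sparse indexer; the paths are placeholders, and the k (top-k) parameter is
# assumed to come from the base SearchConfig:
#
#     from pathlib import Path
#
#     module = BiEncoderModule("path/to/checkpoint")
#     searcher = SparseSearcher(
#         index_dir=Path("path/to/index"),
#         search_config=SparseSearchConfig(k=10),
#         module=module,
#         use_gpu=False,
#     )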