1"""
2Module module for bi-encoder models.
3
4This module defines the Lightning IR module class used to implement bi-encoder models.
5"""

from pathlib import Path
from typing import List, Sequence, Tuple

import torch
from transformers import BatchEncoding

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
from ..retrieve import SearchConfig, Searcher
from .config import BiEncoderConfig
from .model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .tokenizer import BiEncoderTokenizer


class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
    ):
        """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
        :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: BiEncoderConfig to apply when loading from backbone model, defaults to None
        :type config: BiEncoderConfig | None, optional
        :param model: Already instantiated BiEncoderModel, defaults to None
        :type model: BiEncoderModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning; optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param index_dir: Path to an index used for retrieval, defaults to None
        :type index_dir: Path | None, optional
        :param search_config: Configuration to use during retrieval, defaults to None
        :type search_config: SearchConfig | None, optional
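
        A minimal usage sketch; the backbone name is an illustrative choice and the default config is an assumption:

        .. code-block:: python

            module = BiEncoderModule(
                model_name_or_path="bert-base-uncased",  # illustrative backbone, any supported model works
                config=BiEncoderConfig(),
            )
            output = module.score("what is information retrieval?", ["Information retrieval is ..."])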
52 """
53 super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
54 self.model: BiEncoderModel
55 self.config: BiEncoderConfig
56 self.tokenizer: BiEncoderTokenizer
57 self.scoring_function = self.model.scoring_function
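        # if marker tokens were added to the vocabulary, grow the embedding matrix to match
        # (the second argument pads the embedding matrix to a multiple of 8 for GPU efficiency)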
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        :return: Searcher instance
        :rtype: Searcher | None
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)

    def on_test_start(self) -> None:
82 """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
83 self._init_searcher()
84 return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
87 """Runs a forward pass of the model on a batch of data. The output will vary depending on the type of batch. If
88 the batch is a :class`.RankBatch`, query and document embeddings are computed and the relevance score is
89 computed using the :attr:`.scoring_function`. If the batch is an :class:`.IndexBatch`, only document embeddings
90 are comuputed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
91 the model will additionally retrieve documents if :attr:`.searcher` is set.
92
93 :param batch: Input batch containg
94 :type batch: RankBatch | IndexBatch | SearchBatch
95 :raises ValueError: If the input batch contains neither queries nor documents
96 :return: Output of the model
97 :rtype: BiEncoderOutput
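
        A minimal sketch of how the batch type steers the computation (the batch variables are hypothetical):

        .. code-block:: python

            output = module.forward(rank_batch)    # query and doc embeddings plus scores
            output = module.forward(index_batch)   # doc embeddings only
            output = module.forward(search_batch)  # query embeddings; retrieves docs if a searcher is set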
98 """
99 queries = getattr(batch, "queries", None)
100 docs = getattr(batch, "docs", None)
101 num_docs = None
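        # a RankBatch nests docs per query; flatten them and record how many docs each query has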
        if isinstance(batch, RankBatch):
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            scores, doc_ids, num_docs = self.searcher.search(output)
            output.scores = scores
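            # the searcher returns a flat sequence of doc ids; regroup it into per-query tuples
            # using cumulative counts over num_docs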
            cum_num_docs = [0] + [sum(num_docs[: i + 1]) for i in range(len(num_docs))]
            doc_ids = tuple(tuple(doc_ids[cum_num_docs[i] : cum_num_docs[i + 1]]) for i in range(len(num_docs)))
            batch.doc_ids = doc_ids
        return output

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
121 """Computes relevance scores for queries and documents.
122
123 :param queries: Queries to score
124 :type queries: Sequence[str]
125 :param docs: Documents to score
126 :type docs: Sequence[Sequence[str]]
127 :return: Model output
128 :rtype: BiEncoderOutput
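
        A minimal sketch (the query and document strings are made-up examples):

        .. code-block:: python

            output = module.score(
                "what is information retrieval?",
                ["Information retrieval is the task of finding relevant documents for a query."],
            )
            print(output.scores)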
129 """
130 return super().score(queries, docs)

    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")
        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError("targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch")

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
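        # dispatch on the loss function type: in-batch losses select in-batch negatives and
        # re-score them, embedding losses act on the embeddings, scoring losses use the targets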
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
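            # log the average number of non-zero entries per query/document embedding
            # as a diagnostic of how sparse the learned representations are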
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        ib_scoring_mask = torch.cat(
            [
                embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
            ],
            dim=1,
        ).view(-1, num_embs)
        ib_encoding = {}
        for key, value in embeddings.encoding.items():
            seq_len = value.shape[-1]
            ib_encoding[key] = torch.cat(
                [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                dim=1,
            ).view(-1, seq_len)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, BatchEncoding(ib_encoding))

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | IndexBatch | SearchBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: BiEncoderOutput
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")