1"""
2Model module for bi-encoder models.
3
4This module defines the model class used to implement bi-encoder models.
5"""
6
7import warnings
8from dataclasses import dataclass
9from functools import wraps
10from string import punctuation
11from typing import Callable, Iterable, Literal, Sequence, Tuple, Type, overload
12
13import torch
14from transformers import BatchEncoding
15from transformers.activations import ACT2FN
16
17from ..base import LightningIRModel, LightningIROutput
18from ..base.model import batch_encoding_wrapper
19from . import BiEncoderConfig
20
21
class MLMHead(torch.nn.Module):
    def __init__(self, config: BiEncoderConfig) -> None:
        """Masked language model head. Projects the hidden states to the vocabulary size for MLM training.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        """
        super().__init__()
        self.config = config
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        # share the separately registered bias with the decoder so it stays in sync when weights are (re)tied
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Head forward pass.

        :param hidden_states: Hidden states from the backbone model
        :type hidden_states: torch.Tensor
        :return: Hidden states projected to the vocabulary size
        :rtype: torch.Tensor
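
        A minimal shape sketch (this assumes a BERT-style ``BiEncoderConfig`` named ``config`` with
        ``hidden_size=768`` and ``vocab_size=30522``)::

            head = MLMHead(config)
            logits = head(torch.randn(2, 16, 768))  # -> (2, 16, 30522)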
52 """
53 hidden_states = self.dense(hidden_states)
54 hidden_states = self.transform_act_fn(hidden_states)
55 hidden_states = self.LayerNorm(hidden_states)
56 hidden_states = self.decoder(hidden_states)
57 return hidden_states
58
59
@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and scoring mask for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor
    """Mask tensor designating which vectors should be ignored during scoring."""
    encoding: BatchEncoding | None
    """Tokenizer encodings used to generate the embeddings."""

    @overload
    def to(self, device: torch.device, /) -> "BiEncoderEmbedding": ...

    @overload
    def to(self, other: "BiEncoderEmbedding", /) -> "BiEncoderEmbedding": ...

    def to(self, device) -> "BiEncoderEmbedding":
        """Moves the embeddings, scoring mask, and encoding to the specified device.

        :param device: Device to move the embeddings to, or another BiEncoderEmbedding object whose device to use
        :type device: torch.device | BiEncoderEmbedding
        :return: Self
        :rtype: BiEncoderEmbedding
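
        A small usage sketch with dummy tensors and an empty ``BatchEncoding`` (``other_embedding`` stands for any
        other ``BiEncoderEmbedding``)::

            emb = BiEncoderEmbedding(torch.randn(2, 4, 8), torch.ones(2, 4, dtype=torch.bool), BatchEncoding())
            emb = emb.to(torch.device("cpu"))  # move to an explicit device
            emb = emb.to(other_embedding)      # or align with another embedding's device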
85 """
86 if isinstance(device, BiEncoderEmbedding):
87 device = device.device
88 self.embeddings = self.embeddings.to(device)
89 self.scoring_mask = self.scoring_mask.to(device)
90 self.encoding = self.encoding.to(device)
91 return self
92
    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        :raises ValueError: If the embeddings and scoring_mask are not on the same device
        :return: The device of the embeddings
        :rtype: torch.device
        """
        if self.embeddings.device != self.scoring_mask.device:
            raise ValueError("Embeddings and scoring_mask must be on the same device")
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the embedding attributes and their values like `dict.items()`.

        :yield: Tuple of attribute name and its value
        :rtype: Iterator[Tuple[str, torch.Tensor]]
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)


@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings and scoring_mask generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings and scoring_mask generated by the model."""


class BiEncoderModel(LightningIRModel):

    _tied_weights_keys = ["projection.decoder.bias", "projection.decoder.weight", "encoder.embed_tokens.weight"]
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """A bi-encoder model that encodes queries and documents separately and computes a relevance score between them
        using a :class:`.ScoringFunction`. See :class:`.BiEncoderConfig` for configuration options.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        :raises ValueError: If the projection type specified in the configuration is not supported
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig
        self.scoring_function = ScoringFunction(self.config)
        self.projection: torch.nn.Linear | MLMHead | None = None
        if self.config.projection is not None:
            if "linear" in self.config.projection:
                self.projection = torch.nn.Linear(
                    self.config.hidden_size,
                    self.config.embedding_dim,
                    bias="no_bias" not in self.config.projection,
                )
            elif self.config.projection == "mlm":
                self.projection = MLMHead(config)
            else:
                raise ValueError(f"Unknown projection {self.config.projection}")
        else:
            if self.config.embedding_dim != self.config.hidden_size:
                warnings.warn(
                    "No projection is used but embedding_dim != hidden_size. "
                    "The output embeddings will not have embedding_dim dimensions."
                )

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None
        self._add_mask_scoring_input_ids()

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
        if model.config.projection == "mlm":
            has_base_model_prefix = any(s.startswith(model.base_model_prefix) for s in state_dict.keys())
            prefix = model.base_model_prefix + "." if has_base_model_prefix else ""
            for key in list(state_dict.keys()):
                if key.startswith("cls"):
                    # remap checkpoint MLM head weights (e.g., "cls.predictions.transform.dense.weight")
                    # onto the projection head (e.g., "projection.dense.weight")
                    new_key = prefix + key.replace("cls.predictions", "projection").replace(".transform", "")
                    state_dict[new_key] = state_dict.pop(key)
                    loaded_keys[loaded_keys.index(key)] = new_key
        return super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        :return: Output embeddings of the model
        :rtype: torch.nn.Module | None
        """
        if isinstance(self.projection, MLMHead):
            return self.projection.decoder
        return None

    def _add_mask_scoring_input_ids(self) -> None:
        """Adds the mask scoring input ids to the model if they are specified in the configuration."""
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        :param query_encoding: Tokenizer encodings for the queries
        :type query_encoding: BatchEncoding | None
        :param doc_encoding: Tokenizer encodings for the documents
        :type doc_encoding: BatchEncoding | None
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Output of the model
        :rtype: BiEncoderOutput
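
        For example, with two queries and ``num_docs=[2, 1]``, ``doc_encoding`` must contain three documents: the
        first two are scored against the first query and the third against the second. A minimal usage sketch
        (assuming ``query_encoding`` and ``doc_encoding`` were produced by the matching bi-encoder tokenizer)::

            output = model(query_encoding, doc_encoding, num_docs=[2, 1])
            output.scores  # tensor of shape (3,), one score per query-document pair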
240 """
241 query_embeddings = None
242 if query_encoding is not None:
243 query_embeddings = self.encode_query(query_encoding)
244 doc_embeddings = None
245 if doc_encoding is not None:
246 doc_embeddings = self.encode_doc(doc_encoding)
247 scores = None
248 if doc_embeddings is not None and query_embeddings is not None:
249 scores = self.score(query_embeddings, doc_embeddings, num_docs)
250 return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
251
    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        :param encoding: Tokenizer encodings for the queries
        :type encoding: BatchEncoding
        :return: Query embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(
            encoding=encoding,
            expansion=self.config.query_expansion,
            pooling_strategy=self.config.query_pooling_strategy,
            mask_scoring_input_ids=self.query_mask_scoring_input_ids,
        )

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        :param encoding: Tokenizer encodings for the documents
        :type encoding: BatchEncoding
        :return: Document embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(
            encoding=encoding,
            expansion=self.config.doc_expansion,
            pooling_strategy=self.config.doc_pooling_strategy,
            mask_scoring_input_ids=self.doc_mask_scoring_input_ids,
        )

    @batch_encoding_wrapper
    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param expansion: Whether mask expansion was applied to the text sequence, defaults to False
        :type expansion: bool, optional
        :param pooling_strategy: Strategy to pool token embeddings into a single embedding. If None, no pooling is
            applied, defaults to None
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param mask_scoring_input_ids: Which token_ids to mask out during scoring, defaults to None
        :type mask_scoring_input_ids: torch.Tensor | None, optional
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
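
        A rough sketch of what masked mean pooling (``pooling_strategy="mean"``) computes over the token embeddings,
        shown here with plain tensors (the actual pooling is implemented in the base model)::

            token_emb = torch.randn(2, 5, 8)                         # (batch_size, seq_len, hidden_size)
            mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # attention mask
            pooled = (token_emb * mask.unsqueeze(-1)).sum(1, keepdim=True)
            pooled = pooled / mask.sum(1, keepdim=True).unsqueeze(-1)  # -> (2, 1, 8)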
303 """
304 embeddings = self._backbone_forward(**encoding).last_hidden_state
305 if self.projection is not None:
306 embeddings = self.projection(embeddings)
307 embeddings = self._sparsification(embeddings, self.config.sparsification)
308 embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
309 if self.config.normalize:
310 embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
311 scoring_mask = self.scoring_mask(encoding, expansion, pooling_strategy, mask_scoring_input_ids)
312 return BiEncoderEmbedding(embeddings, scoring_mask, encoding)
313
    def scoring_mask(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param expansion: Whether or not mask expansion was applied to the tokenized sequence, defaults to False
        :type expansion: bool, optional
        :param pooling_strategy: Pooling strategy used to pool the embeddings, defaults to None
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param mask_scoring_input_ids: Sequence of token_ids which should be masked during scoring, defaults to None
        :type mask_scoring_input_ids: torch.Tensor | None, optional
        :return: Scoring mask
        :rtype: torch.Tensor
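
        Illustration of the token-masking step with plain tensors (the token ids are made up)::

            input_ids = torch.tensor([[101, 7592, 1010, 102]])
            mask_scoring_input_ids = torch.tensor([1010, 102])
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids).any(-1)
            # ignore_mask: tensor([[False, False, True, True]])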
334 """
335 device = encoding["input_ids"].device
336 input_ids: torch.Tensor = encoding["input_ids"]
337 attention_mask: torch.Tensor = encoding["attention_mask"]
338 shape = input_ids.shape
339 if pooling_strategy is not None:
340 return torch.ones((shape[0], 1), dtype=torch.bool, device=device)
341 scoring_mask = attention_mask
342 if expansion or scoring_mask is None:
343 scoring_mask = torch.ones(shape, dtype=torch.bool, device=device)
344 scoring_mask = scoring_mask.bool()
345 if mask_scoring_input_ids is not None:
346 ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
347 scoring_mask = scoring_mask & ~ignore_mask
348 return scoring_mask
349
    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        scores = self.scoring_function(query_embeddings, doc_embeddings, num_docs=num_docs)
        return scores


def _batch_scoring(
    similarity_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Helper function to batch similarity functions to avoid memory issues with large batch sizes or high numbers
    of documents per query."""
    BATCH_SIZE = 1024

    @wraps(similarity_function)
    def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.shape[0] <= BATCH_SIZE:
            return similarity_function(x, y)
        out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
        for i in range(0, x.shape[0], BATCH_SIZE):
            out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE])
        return out

    return batch_similarity_function


class ScoringFunction(torch.nn.Module):
    def __init__(self, config: BiEncoderConfig) -> None:
        """Scoring function for bi-encoder models. Computes similarity scores between query and document embeddings.
        For multi-vector models, the scores are aggregated to a single score per query-document pair.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        :raises ValueError: If the similarity function is not supported
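
        A late-interaction sketch with dummy tensors, roughly matching the shapes this module uses internally: each
        query vector is matched against its most similar document vector and the maxima are aggregated (``sum``
        aggregation shown)::

            q = torch.randn(3, 4, 1, 8)  # 3 query-document pairs, 4 query vectors each
            d = torch.randn(3, 1, 6, 8)  # 6 document vectors per document
            sim = torch.nn.functional.cosine_similarity(q, d, dim=-1)  # (3, 4, 6)
            scores = sim.max(-1).values.sum(-1)                        # (3,)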
402 """
403 super().__init__()
404 self.config = config
405 if self.config.similarity_function == "cosine":
406 self.similarity_function = self._cosine_similarity
407 elif self.config.similarity_function == "l2":
408 self.similarity_function = self._l2_similarity
409 elif self.config.similarity_function == "dot":
410 self.similarity_function = self._dot_similarity
411 else:
412 raise ValueError(f"Unknown similarity function {self.config.similarity_function}")
413 self.query_aggregation_function = self.config.query_aggregation_function
414
    def _compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs."""
        # TODO compute similarity only for non-masked values
        similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings)
        return similarity

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

    def _parse_num_docs(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: int | Sequence[int] | None,
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        batch_size = query_embeddings.embeddings.shape[0]
        if isinstance(num_docs, int):
            num_docs = [num_docs] * batch_size
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_embeddings.embeddings.shape[0] % batch_size != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size
        return torch.tensor(num_docs, device=query_embeddings.embeddings.device)

    def _expand_query_embeddings(self, embeddings: BiEncoderEmbedding, num_docs: torch.Tensor) -> BiEncoderEmbedding:
        """Helper function to expand query embeddings to match the number of documents per query."""
        # (num_queries, query_len, emb_dim) -> (total_num_docs, query_len, 1, emb_dim)
        return BiEncoderEmbedding(
            embeddings.embeddings.repeat_interleave(num_docs, dim=0).unsqueeze(2),
            embeddings.scoring_mask.repeat_interleave(num_docs, dim=0).unsqueeze(2),
            embeddings.encoding,
        )

    def _expand_doc_embeddings(self, embeddings: BiEncoderEmbedding, num_docs: torch.Tensor) -> BiEncoderEmbedding:
        """Helper function to expand document embeddings to match the number of documents per query."""
        # (total_num_docs, doc_len, emb_dim) -> (total_num_docs, 1, doc_len, emb_dim)
        return BiEncoderEmbedding(
            embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1), embeddings.encoding
        )

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor | None,
        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None,
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        if query_aggregation_function is None:
            return scores
        if query_aggregation_function == "max":
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            return scores.max(dim, keepdim=True).values
        if query_aggregation_function == "sum":
            if mask is not None:
                scores = scores.masked_fill(~mask, 0)
            return scores.sum(dim, keepdim=True)
        if mask is None:
            shape = list(scores.shape)
            shape[dim] = 1
            num_non_masked = torch.full(shape, scores.shape[dim], device=scores.device)
        else:
            num_non_masked = mask.sum(dim, keepdim=True)
        if query_aggregation_function == "mean":
            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
        if query_aggregation_function == "harmonic_mean":
            return torch.where(
                num_non_masked == 0,
                0,
                num_non_masked / (1 / scores).sum(dim, keepdim=True),
            )
        raise ValueError(f"Unknown aggregation {query_aggregation_function}")

    def forward(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between query and document embeddings.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(query_embeddings, doc_embeddings, num_docs)
        query_embeddings = self._expand_query_embeddings(query_embeddings, num_docs_t)
        doc_embeddings = self._expand_doc_embeddings(doc_embeddings, num_docs_t)
        similarity = self._compute_similarity(query_embeddings, doc_embeddings)
        # aggregate over document vectors (take the maximum similarity per query vector) ...
        scores = self._aggregate(similarity, doc_embeddings.scoring_mask, "max", -1)
        # ... then aggregate over query vectors with the configured aggregation function
        scores = self._aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2)
        return scores[..., 0, 0]