Source code for lightning_ir.bi_encoder.model

  1"""
  2Model module for bi-encoder models.
  3
  4This module defines the model class used to implement bi-encoder models.
  5"""
  6
  7import warnings
  8from dataclasses import dataclass
  9from functools import wraps
 10from string import punctuation
 11from typing import Callable, Iterable, Literal, Sequence, Tuple, Type, overload
 12
 13import torch
 14from transformers import BatchEncoding
 15from transformers.activations import ACT2FN
 16
 17from ..base import LightningIRModel, LightningIROutput
 18from ..base.model import batch_encoding_wrapper
 19from . import BiEncoderConfig
 20
 21
class MLMHead(torch.nn.Module):
    def __init__(self, config: BiEncoderConfig) -> None:
        """Masked language model head. Projects the hidden states to the vocabulary size for MLM training.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        """
        super().__init__()
        self.config = config
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Head forward pass.

        :param hidden_states: Hidden states from the backbone model
        :type hidden_states: torch.Tensor
        :return: Projected hidden states to the vocabulary size
        :rtype: torch.Tensor
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

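# Hedged shape sketch (illustrative, not part of the library source): when the bi-encoder
# is configured with projection="mlm", this head maps backbone hidden states of shape
# [batch_size, seq_len, hidden_size] to vocabulary-sized logits of shape
# [batch_size, seq_len, vocab_size], e.g. for SPLADE-style sparse models.
#
#   head = MLMHead(model.config)                           # assumes an existing BiEncoderModel `model`
#   hidden = torch.randn(2, 32, model.config.hidden_size)  # [batch_size, seq_len, hidden_size]
#   logits = head(hidden)                                  # [batch_size, seq_len, vocab_size]
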
@dataclass
class BiEncoderEmbedding:
    """Dataclass containing embeddings and scoring mask for bi-encoder models."""

    embeddings: torch.Tensor
    """Embedding tensor generated by a bi-encoder model of shape [batch_size x seq_len x hidden_size]. The sequence
    length varies depending on the pooling strategy and the hidden size varies depending on the projection settings."""
    scoring_mask: torch.Tensor
    """Mask tensor designating which vectors should be taken into account (True) and which should be ignored (False)
    during scoring."""
    encoding: BatchEncoding | None
    """Tokenizer encodings used to generate the embeddings."""

    @overload
    def to(self, device: torch.device, /) -> "BiEncoderEmbedding": ...

    @overload
    def to(self, other: "BiEncoderEmbedding", /) -> "BiEncoderEmbedding": ...

    def to(self, device) -> "BiEncoderEmbedding":
        """Moves the embeddings and scoring mask to the specified device.

        :param device: Device to move the embeddings to or another BiEncoderEmbedding object to move to the same device
        :type device: torch.device | BiEncoderEmbedding
        :return: Self
        :rtype: BiEncoderEmbedding
        """
        if isinstance(device, BiEncoderEmbedding):
            device = device.device
        self.embeddings = self.embeddings.to(device)
        self.scoring_mask = self.scoring_mask.to(device)
        if self.encoding is not None:
            self.encoding = self.encoding.to(device)
        return self

    @property
    def device(self) -> torch.device:
        """Returns the device of the embeddings.

        :raises ValueError: If the embeddings and scoring_mask are not on the same device
        :return: The device of the embeddings
        :rtype: torch.device
        """
        if self.embeddings.device != self.scoring_mask.device:
            raise ValueError("Embeddings and scoring_mask must be on the same device")
        return self.embeddings.device

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        """Iterates over the embedding attributes and their values like `dict.items()`.

        :yield: Tuple of attribute name and its value
        :rtype: Iterable[Tuple[str, torch.Tensor]]
        """
        for field in self.__dataclass_fields__:
            yield field, getattr(self, field)

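# Hedged usage sketch (illustrative, not part of the library source): a BiEncoderEmbedding
# bundles the embedding tensor, its scoring mask, and the tokenizer encoding, and can be
# moved between devices much like a tensor.
#
#   emb = model.encode_doc(doc_encoding)      # BiEncoderEmbedding
#   emb = emb.to(torch.device("cpu"))         # moves embeddings, scoring_mask, and encoding
#   for name, value in emb.items():           # iterate fields like dict.items()
#       print(name, getattr(value, "shape", None))
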
@dataclass
class BiEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a bi-encoder model."""

    query_embeddings: BiEncoderEmbedding | None = None
    """Query embeddings and scoring_mask generated by the model."""
    doc_embeddings: BiEncoderEmbedding | None = None
    """Document embeddings and scoring_mask generated by the model."""

class BiEncoderModel(LightningIRModel):

    _tied_weights_keys = ["projection.decoder.bias", "projection.decoder.weight", "encoder.embed_tokens.weight"]
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the bi-encoder model."""

    def __init__(self, config: BiEncoderConfig, *args, **kwargs) -> None:
        """A bi-encoder model that encodes queries and documents separately and computes a relevance score between
        them using a :class:`.ScoringFunction`. See :class:`.BiEncoderConfig` for configuration options.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        :raises ValueError: If the projection option in the configuration is unknown
        """
        super().__init__(config, *args, **kwargs)
        self.config: BiEncoderConfig
        self.scoring_function = ScoringFunction(self.config)
        self.projection: torch.nn.Linear | MLMHead | None = None
        if self.config.projection is not None:
            if "linear" in self.config.projection:
                self.projection = torch.nn.Linear(
                    self.config.hidden_size,
                    self.config.embedding_dim,
                    bias="no_bias" not in self.config.projection,
                )
            elif self.config.projection == "mlm":
                self.projection = MLMHead(config)
            else:
                raise ValueError(f"Unknown projection {self.config.projection}")
        else:
            if self.config.embedding_dim != self.config.hidden_size:
                warnings.warn(
                    "No projection is used but embedding_dim != hidden_size. "
                    "The output embeddings will not have embedding_dim dimensions."
                )

        self.query_mask_scoring_input_ids: torch.Tensor | None = None
        self.doc_mask_scoring_input_ids: torch.Tensor | None = None
        self._add_mask_scoring_input_ids()

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
        if model.config.projection == "mlm":
            has_base_model_prefix = any(s.startswith(model.base_model_prefix) for s in state_dict.keys())
            prefix = model.base_model_prefix + "." if has_base_model_prefix else ""
            for key in list(state_dict.keys()):
                if key.startswith("cls"):
                    new_key = prefix + key.replace("cls.predictions", "projection").replace(".transform", "")
                    state_dict[new_key] = state_dict.pop(key)
                    loaded_keys[loaded_keys.index(key)] = new_key
        return super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )

    def get_output_embeddings(self) -> torch.nn.Module | None:
        """Returns the output embeddings of the model for tying the input and output embeddings. Returns None if no
        MLM head is used for projection.

        :return: Output embeddings of the model
        :rtype: torch.nn.Module | None
        """
        if isinstance(self.projection, MLMHead):
            return self.projection.decoder
        return None

    def _add_mask_scoring_input_ids(self) -> None:
        """Adds the mask scoring input ids to the model if they are specified in the configuration."""
        for sequence in ("query", "doc"):
            mask_scoring_tokens = getattr(self.config, f"{sequence}_mask_scoring_tokens")
            if mask_scoring_tokens is None:
                continue
            if mask_scoring_tokens == "punctuation":
                mask_scoring_tokens = list(punctuation)
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
            except OSError:
                raise ValueError("Can't use token scoring masking if the checkpoint does not have a tokenizer.")
            mask_scoring_input_ids = []
            for token in mask_scoring_tokens:
                if token not in tokenizer.vocab:
                    raise ValueError(f"Token {token} not in tokenizer vocab")
                mask_scoring_input_ids.append(tokenizer.vocab[token])
            setattr(
                self,
                f"{sequence}_mask_scoring_input_ids",
                torch.tensor(mask_scoring_input_ids, dtype=torch.long),
            )

    def forward(
        self,
        query_encoding: BatchEncoding | None,
        doc_encoding: BatchEncoding | None,
        num_docs: Sequence[int] | int | None = None,
    ) -> BiEncoderOutput:
        """Embeds queries and/or documents and computes relevance scores between them if both are provided.

        :param query_encoding: Tokenizer encodings for the queries
        :type query_encoding: BatchEncoding | None
        :param doc_encoding: Tokenizer encodings for the documents
        :type doc_encoding: BatchEncoding | None
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by
            dividing the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Output of the model
        :rtype: BiEncoderOutput
        """
        query_embeddings = None
        if query_encoding is not None:
            query_embeddings = self.encode_query(query_encoding)
        doc_embeddings = None
        if doc_encoding is not None:
            doc_embeddings = self.encode_doc(doc_encoding)
        scores = None
        if doc_embeddings is not None and query_embeddings is not None:
            scores = self.score(query_embeddings, doc_embeddings, num_docs)
        return BiEncoderOutput(scores=scores, query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)

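    # Hedged usage sketch (illustrative, not part of the library source): given tokenizer
    # encodings for 2 queries and 5 documents, num_docs=[3, 2] means the first three
    # documents are scored against the first query and the remaining two against the
    # second query.
    #
    #   output = model(query_encoding, doc_encoding, num_docs=[3, 2])
    #   output.scores            # tensor with 5 relevance scores, grouped per query
    #   output.query_embeddings  # BiEncoderEmbedding for the queries
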
    def encode_query(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized queries.

        :param encoding: Tokenizer encodings for the queries
        :type encoding: BatchEncoding
        :return: Query embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(
            encoding=encoding,
            expansion=self.config.query_expansion,
            pooling_strategy=self.config.query_pooling_strategy,
            mask_scoring_input_ids=self.query_mask_scoring_input_ids,
        )

    def encode_doc(self, encoding: BatchEncoding) -> BiEncoderEmbedding:
        """Encodes tokenized documents.

        :param encoding: Tokenizer encodings for the documents
        :type encoding: BatchEncoding
        :return: Document embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        return self.encode(
            encoding=encoding,
            expansion=self.config.doc_expansion,
            pooling_strategy=self.config.doc_pooling_strategy,
            mask_scoring_input_ids=self.doc_mask_scoring_input_ids,
        )

    @batch_encoding_wrapper
    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param expansion: Whether mask expansion was applied to the text sequence, defaults to False
        :type expansion: bool, optional
        :param pooling_strategy: Strategy to pool token embeddings into a single embedding. If None, no pooling is
            applied, defaults to None
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param mask_scoring_input_ids: Which token_ids to mask out during scoring, defaults to None
        :type mask_scoring_input_ids: torch.Tensor | None, optional
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        if self.projection is not None:
            embeddings = self.projection(embeddings)
        embeddings = self._sparsification(embeddings, self.config.sparsification)
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, expansion, pooling_strategy, mask_scoring_input_ids)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

    def scoring_mask(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences. The mask is used in the scoring function to
        mask out vectors during scoring.

        :param encoding: Tokenizer encodings for the text sequence
        :type encoding: BatchEncoding
        :param expansion: Whether or not mask expansion was applied to the tokenized sequence, defaults to False
        :type expansion: bool, optional
        :param pooling_strategy: Which pooling strategy was used to pool the embeddings, defaults to None
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param mask_scoring_input_ids: Sequence of token_ids which should be masked during scoring, defaults to None
        :type mask_scoring_input_ids: torch.Tensor | None, optional
        :return: Scoring mask
        :rtype: torch.Tensor
        """
        device = encoding["input_ids"].device
        input_ids: torch.Tensor = encoding["input_ids"]
        attention_mask: torch.Tensor = encoding["attention_mask"]
        shape = input_ids.shape
        if pooling_strategy is not None:
            return torch.ones((shape[0], 1), dtype=torch.bool, device=device)
        scoring_mask = attention_mask
        if expansion or scoring_mask is None:
            scoring_mask = torch.ones(shape, dtype=torch.bool, device=device)
        scoring_mask = scoring_mask.bool()
        if mask_scoring_input_ids is not None:
            ignore_mask = input_ids[..., None].eq(mask_scoring_input_ids.to(device)).any(-1)
            scoring_mask = scoring_mask & ~ignore_mask
        return scoring_mask

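    # Hedged behavior sketch (illustrative, not part of the library source): with a pooling
    # strategy the sequence is reduced to a single vector, so the mask has shape
    # [batch_size, 1]; without pooling it has shape [batch_size, seq_len], with padding
    # tokens (and any mask_scoring_input_ids) set to False.
    #
    #   mask = model.scoring_mask(doc_encoding, pooling_strategy="mean")  # [batch_size, 1], all True
    #   mask = model.scoring_mask(doc_encoding, pooling_strategy=None)    # [batch_size, seq_len]
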
    def score(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between queries and documents.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by
            dividing the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        scores = self.scoring_function(query_embeddings, doc_embeddings, num_docs=num_docs)
        return scores


def _batch_scoring(
    similarity_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Helper function to batch similarity functions to avoid memory issues with large batch sizes or high numbers
    of documents per query."""
    BATCH_SIZE = 1024

    @wraps(similarity_function)
    def batch_similarity_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.shape[0] <= BATCH_SIZE:
            return similarity_function(x, y)
        out = torch.zeros(x.shape[0], x.shape[1], y.shape[2], device=x.device, dtype=x.dtype)
        for i in range(0, x.shape[0], BATCH_SIZE):
            out[i : i + BATCH_SIZE] = similarity_function(x[i : i + BATCH_SIZE], y[i : i + BATCH_SIZE])
        return out

    return batch_similarity_function

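# Hedged usage sketch (illustrative, not part of the library source): _batch_scoring splits
# the leading (query-document pair) dimension into chunks of at most 1024 so that very large
# batches do not exhaust memory; the wrapped function is applied chunk-wise and the results
# are written into a preallocated output tensor.
#
#   @_batch_scoring
#   def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#       return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)
#
#   x = torch.randn(5000, 1, 1, 128)   # expanded query vectors [num_pairs, query_len, 1, dim]
#   y = torch.randn(5000, 1, 32, 128)  # expanded document vectors [num_pairs, 1, doc_len, dim]
#   scores = dot(x, y)                 # [num_pairs, query_len, doc_len], computed in chunks of 1024
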
class ScoringFunction(torch.nn.Module):
    def __init__(self, config: BiEncoderConfig) -> None:
        """Scoring function for bi-encoder models. Computes similarity scores between query and document embeddings.
        For multi-vector models, the scores are aggregated to a single score per query-document pair.

        :param config: Configuration for the bi-encoder model
        :type config: BiEncoderConfig
        :raises ValueError: If the similarity function is not supported
        """
        super().__init__()
        self.config = config
        if self.config.similarity_function == "cosine":
            self.similarity_function = self._cosine_similarity
        elif self.config.similarity_function == "l2":
            self.similarity_function = self._l2_similarity
        elif self.config.similarity_function == "dot":
            self.similarity_function = self._dot_similarity
        else:
            raise ValueError(f"Unknown similarity function {self.config.similarity_function}")
        self.query_aggregation_function = self.config.query_aggregation_function

    def _compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        """Computes the similarity score between all query and document embedding vector pairs."""
        # TODO compute similarity only for non-masked values
        similarity = self.similarity_function(query_embeddings.embeddings, doc_embeddings.embeddings)
        return similarity

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.cosine_similarity(x, y, dim=-1)

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _l2_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return -1 * torch.cdist(x, y).squeeze(-2)

    @staticmethod
    @_batch_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def _dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y.transpose(-1, -2)).squeeze(-2)

    def _parse_num_docs(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: int | Sequence[int] | None,
    ) -> torch.Tensor:
        """Helper function to parse the number of documents per query."""
        batch_size = query_embeddings.embeddings.shape[0]
        if isinstance(num_docs, int):
            num_docs = [num_docs] * batch_size
        if isinstance(num_docs, list):
            if sum(num_docs) != doc_embeddings.embeddings.shape[0] or len(num_docs) != batch_size:
                raise ValueError("Num docs does not match doc embeddings")
        if num_docs is None:
            if doc_embeddings.embeddings.shape[0] % batch_size != 0:
                raise ValueError("Docs are not evenly distributed in batch, but no num_docs provided")
            num_docs = [doc_embeddings.embeddings.shape[0] // batch_size] * batch_size
        return torch.tensor(num_docs, device=query_embeddings.embeddings.device)

    def _expand_query_embeddings(self, embeddings: BiEncoderEmbedding, num_docs: torch.Tensor) -> BiEncoderEmbedding:
        """Helper function to expand query embeddings to match the number of documents per query."""
        return BiEncoderEmbedding(
            embeddings.embeddings.repeat_interleave(num_docs, dim=0).unsqueeze(2),
            embeddings.scoring_mask.repeat_interleave(num_docs, dim=0).unsqueeze(2),
            embeddings.encoding,
        )

    def _expand_doc_embeddings(self, embeddings: BiEncoderEmbedding, num_docs: torch.Tensor) -> BiEncoderEmbedding:
        """Helper function to expand document embeddings to match the number of documents per query."""
        return BiEncoderEmbedding(
            embeddings.embeddings.unsqueeze(1), embeddings.scoring_mask.unsqueeze(1), embeddings.encoding
        )

    def _aggregate(
        self,
        scores: torch.Tensor,
        mask: torch.Tensor | None,
        query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"] | None,
        dim: int,
    ) -> torch.Tensor:
        """Helper function to aggregate similarity scores over query and document embeddings."""
        if query_aggregation_function is None:
            return scores
        if query_aggregation_function == "max":
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            return scores.max(dim, keepdim=True).values
        if query_aggregation_function == "sum":
            if mask is not None:
                scores = scores.masked_fill(~mask, 0)
            return scores.sum(dim, keepdim=True)
        if mask is None:
            shape = list(scores.shape)
            shape[dim] = 1
            num_non_masked = torch.full(shape, scores.shape[dim], device=scores.device)
        else:
            num_non_masked = mask.sum(dim, keepdim=True)
        if query_aggregation_function == "mean":
            return torch.where(num_non_masked == 0, 0, scores.sum(dim, keepdim=True) / num_non_masked)
        if query_aggregation_function == "harmonic_mean":
            return torch.where(
                num_non_masked == 0,
                0,
                num_non_masked / (1 / scores).sum(dim, keepdim=True),
            )
        raise ValueError(f"Unknown aggregation {query_aggregation_function}")

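    # Hedged aggregation sketch (illustrative, not part of the library source): for
    # multi-vector (ColBERT-style) models the similarity tensor has shape
    # [num_pairs, query_len, doc_len]. forward() first takes the max over the document
    # dimension (MaxSim) and then aggregates over the query dimension with
    # query_aggregation_function, e.g. "sum":
    #
    #   sim = torch.randn(2, 4, 7)                   # 2 pairs, 4 query tokens, 7 doc tokens
    #   max_sim = sim.max(-1, keepdim=True).values   # [2, 4, 1]
    #   scores = max_sim.sum(-2, keepdim=True)       # [2, 1, 1] -> one score per pair
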
    def forward(
        self,
        query_embeddings: BiEncoderEmbedding,
        doc_embeddings: BiEncoderEmbedding,
        num_docs: Sequence[int] | int | None = None,
    ) -> torch.Tensor:
        """Compute relevance scores between query and document embeddings.

        :param query_embeddings: Embeddings and scoring mask for the queries
        :type query_embeddings: BiEncoderEmbedding
        :param doc_embeddings: Embeddings and scoring mask for the documents
        :type doc_embeddings: BiEncoderEmbedding
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by
            dividing the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Relevance scores
        :rtype: torch.Tensor
        """
        num_docs_t = self._parse_num_docs(query_embeddings, doc_embeddings, num_docs)
        query_embeddings = self._expand_query_embeddings(query_embeddings, num_docs_t)
        doc_embeddings = self._expand_doc_embeddings(doc_embeddings, num_docs_t)
        similarity = self._compute_similarity(query_embeddings, doc_embeddings)
        scores = self._aggregate(similarity, doc_embeddings.scoring_mask, "max", -1)
        scores = self._aggregate(scores, query_embeddings.scoring_mask, self.query_aggregation_function, -2)
        return scores[..., 0, 0]
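

# Hedged end-to-end sketch (illustrative; the checkpoint name is a placeholder and the
# encodings are assumed to come from the matching bi-encoder tokenizer): the model is
# loaded with from_pretrained, queries and documents are encoded separately, and
# forward() returns a BiEncoderOutput with relevance scores and both embeddings.
#
#   model = BiEncoderModel.from_pretrained("path/or/name-of-bi-encoder-checkpoint")
#   output = model(query_encoding, doc_encoding, num_docs=None)  # evenly distributed docs per query
#   print(output.scores)
#   print(output.query_embeddings.embeddings.shape, output.doc_embeddings.embeddings.shape)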