Source code for lightning_ir.cross_encoder.model
1"""
2Model module for cross-encoder models.
3
4This module defines the model class used to implement cross-encoder models.
5"""
6
7from dataclasses import dataclass
8from typing import Type
9
10import torch
11from transformers import BatchEncoding
12
13from ..base import LightningIRModel, LightningIROutput
14from ..base.model import batch_encoding_wrapper
15from . import CrossEncoderConfig
16
17
[docs]
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container returned by a cross-encoder forward pass.

    Adds the joint query-document embeddings on top of the fields inherited from
    :class:`LightningIROutput`.
    """

    embeddings: torch.Tensor | None = None
    """Joint query-document embeddings; ``None`` when not provided"""
23 """Joint query-document embeddings"""
24
25
[docs]
class CrossEncoderModel(LightningIRModel):
    """Cross-encoder that scores a query and document(s) from a single joint input sequence.

    The backbone's last hidden states are pooled according to ``config.pooling_strategy`` and passed
    through a linear layer with one output unit to obtain relevance scores.
    """

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for cross-encoder models."""

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        :param config: Configuration for the cross-encoder model; ``hidden_size`` sets the input width of the
            scoring layer and ``linear_bias`` controls whether that layer has a bias term
        :type config: CrossEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config: CrossEncoderConfig
        # Scoring head: maps a vector of width ``hidden_size`` to a single output unit.
        in_features = config.hidden_size
        self.linear = torch.nn.Linear(in_features, 1, bias=config.linear_bias)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Encodes the joint query-document input sequence, pools the contextualized embeddings, and scores them.

        :param encoding: Tokenizer encoding for the joint query-document input sequence
        :type encoding: BatchEncoding
        :return: Output holding the flattened relevance scores and the pooled embeddings
        :rtype: CrossEncoderOutput
        """
        hidden_states = self._backbone_forward(**encoding).last_hidden_state
        # The attention mask is optional; ``None`` is forwarded when the encoding does not carry one.
        mask = encoding.get("attention_mask", None)
        pooled = self._pooling(hidden_states, mask, pooling_strategy=self.config.pooling_strategy)
        # Flatten the linear output to 1-D. This presumably yields one score per input sequence,
        # assuming the pooling step returns one vector per sequence (TODO: confirm against _pooling).
        relevance = self.linear(pooled).view(-1)
        return CrossEncoderOutput(scores=relevance, embeddings=pooled)