Source code for lightning_ir.cross_encoder.model

 1"""
 2Model module for cross-encoder models.
 3
 4This module defines the model class used to implement cross-encoder models.
 5"""
 6
 7from dataclasses import dataclass
 8from typing import Type
 9
10import torch
11from transformers import BatchEncoding
12
13from ..base import LightningIRModel, LightningIROutput
14from ..base.model import batch_encoding_wrapper
15from . import CrossEncoderConfig
16
17
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container for a cross-encoder forward pass.

    Extends :class:`LightningIROutput` with the joint query-document embeddings. In
    ``CrossEncoderModel.forward`` these are the pooled embeddings that are fed to the scoring layer.
    """

    #: Joint query-document embeddings, or ``None`` when not provided.
    embeddings: torch.Tensor | None = None

class CrossEncoderModel(LightningIRModel):
    #: Configuration class for cross-encoder models.
    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """Build a cross-encoder that encodes a query and its document(s) jointly.

        The backbone's contextualized embeddings are aggregated into a single vector. A linear layer then
        maps that vector to a final relevance score.

        :param config: Configuration for the cross-encoder model
        :type config: CrossEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        # Narrow the type of the inherited ``config`` attribute for type checkers.
        self.config: CrossEncoderConfig
        # Scoring head: hidden_size -> 1. Whether it has a bias term is controlled by the config.
        self.linear = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=1,
            bias=config.linear_bias,
        )

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Encode the joint query-document sequence and score it.

        The backbone's last hidden state is pooled according to ``config.pooling_strategy``. The pooled
        vector is projected by the linear scoring head and flattened into a 1-D score tensor.

        :param encoding: Tokenizer encoding for the joint query-document input sequence
        :type encoding: BatchEncoding
        :return: Output of the model, carrying the scores and the pooled embeddings
        :rtype: CrossEncoderOutput
        """
        hidden_states = self._backbone_forward(**encoding).last_hidden_state
        # The attention mask is optional; pass None to the pooling helper when it is absent.
        attention_mask = encoding.get("attention_mask", None)
        pooled = self._pooling(hidden_states, attention_mask, pooling_strategy=self.config.pooling_strategy)
        # NOTE(review): view(-1) gives one score per input only if pooling yields a single vector per
        # sequence — confirm for every supported pooling strategy.
        flat_scores = self.linear(pooled).view(-1)
        return CrossEncoderOutput(scores=flat_scores, embeddings=pooled)