Source code for lightning_ir.cross_encoder.config

 1"""
 2Configuration module for cross-encoder models.
 3
 4This module defines the configuration class used to instantiate cross-encoder models.
 5"""
 6
 7from typing import Literal
 8
 9from ..base import LightningIRConfig
10
11
class CrossEncoderConfig(LightningIRConfig):
    """Configuration for cross-encoder models.

    Extends :class:`LightningIRConfig` with two additional options,
    ``pooling_strategy`` and ``linear_bias``, which are registered in
    :attr:`ADDED_ARGS` and stored as attributes on the instance.
    """

    model_type: str = "cross-encoder"
    """Model type identifier for cross-encoder models."""

    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling_strategy", "linear_bias"})
    """Base-class added arguments plus the cross-encoder specific ones."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ):
        """Create a cross-encoder configuration.

        The length limits and any extra keyword arguments are handed to the
        parent configuration; the remaining two options are stored on this
        instance.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param pooling_strategy: Pooling strategy used to aggregate the contextualized embeddings into a single
            vector for computing a relevance score, defaults to "first"
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param linear_bias: Whether to use a bias in the prediction linear layer, defaults to False
        :type linear_bias: bool, optional
        :param kwargs: Additional keyword arguments forwarded unchanged to ``LightningIRConfig.__init__``
        """
        # Initialise the parent first so the options set below are applied after
        # whatever the base configuration sets up.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            **kwargs,
        )
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias