Source code for lightning_ir.cross_encoder.config
1"""
2Configuration module for cross-encoder models.
3
4This module defines the configuration class used to instantiate cross-encoder models.
5"""
6
7from typing import Literal
8
9from ..base import LightningIRConfig
10
11
[docs]
class CrossEncoderConfig(LightningIRConfig):
    """Configuration for cross-encoder models.

    Extends :class:`LightningIRConfig` with two cross-encoder options: the strategy used to pool the
    contextualized embeddings into a single vector for scoring, and whether the prediction linear layer
    has a bias term.
    """

    model_type: str = "cross-encoder"
    """Identifier of the cross-encoder model type."""

    ADDED_ARGS = LightningIRConfig.ADDED_ARGS | {"pooling_strategy", "linear_bias"}
    """Configuration arguments this class introduces on top of the parent configuration's arguments."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ):
        """Build a cross-encoder configuration.

        ``query_length``, ``doc_length`` and all extra keyword arguments are forwarded to
        :class:`LightningIRConfig`. The two cross-encoder options are then stored as instance attributes.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param pooling_strategy: How the contextualized embeddings are aggregated into one vector from which
            the relevance score is computed, defaults to "first"
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param linear_bias: Whether the prediction linear layer uses a bias, defaults to False
        :type linear_bias: bool, optional
        """
        # The parent initializer runs first; the cross-encoder options are assigned afterwards.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            **kwargs,
        )
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias