Source code for lightning_ir.models.splade
1from typing import Literal
2
3from ..bi_encoder import BiEncoderConfig, BiEncoderModel
4
5
[docs]
class SpladeConfig(BiEncoderConfig):
    """Bi-encoder configuration registered under the model type ``"splade"``.

    Pooling, projection, sparsification and embedding size are exposed as
    constructor arguments with SPLADE-oriented defaults. A fixed set of other
    bi-encoder options is pinned inside ``__init__`` and cannot be changed
    through this class.
    """

    model_type = "splade"

    def __init__(
        self,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "max",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "max",
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "mlm",
        sparsification: Literal["relu", "relu_log"] | None = "relu_log",
        embedding_dim: int = 30522,
        **kwargs,
    ) -> None:
        """Build the configuration and hand everything to ``BiEncoderConfig``.

        Args:
            query_pooling_strategy: Pooling option for queries; defaults to ``"max"``.
            doc_pooling_strategy: Pooling option for documents; defaults to ``"max"``.
            projection: Projection option; defaults to ``"mlm"``.
            sparsification: Sparsification option; defaults to ``"relu_log"``.
            embedding_dim: Embedding size; defaults to 30522 (presumably the
                vocabulary size of the underlying tokenizer -- confirm).
            **kwargs: Extra options forwarded to the parent constructor. Any
                value supplied for one of the pinned options below is
                overwritten before forwarding.
        """
        # Options this configuration does not leave open: caller-supplied
        # values under these names are replaced unconditionally.
        kwargs.update(
            query_expansion=False,
            attend_to_query_expanded_tokens=False,
            query_mask_scoring_tokens=None,
            doc_expansion=False,
            attend_to_doc_expanded_tokens=False,
            doc_mask_scoring_tokens=None,
            query_aggregation_function="sum",
            normalize=False,
            add_marker_tokens=False,
        )
        # Caller-tunable settings, kept in the keyword order the parent
        # constructor has always received them in.
        tunable = {
            "query_pooling_strategy": query_pooling_strategy,
            "doc_pooling_strategy": doc_pooling_strategy,
            "embedding_dim": embedding_dim,
            "projection": projection,
            "sparsification": sparsification,
        }
        super().__init__(**tunable, **kwargs)
35
36
[docs]
class SpladeModel(BiEncoderModel):
    """Bi-encoder model whose configuration class is ``SpladeConfig``.

    All behaviour comes from ``BiEncoderModel``; this subclass only sets
    ``config_class``.
    """

    # NOTE(review): presumably read by the base-class loading machinery to
    # select the configuration type -- confirm in BiEncoderModel.
    config_class = SpladeConfig