Source code for lightning_ir.models.col
1from typing import Literal, Sequence
2
3from ..bi_encoder import BiEncoderConfig, BiEncoderModel
4
5
[docs]
class ColConfig(BiEncoderConfig):
    """Configuration for the ``"col"`` bi-encoder model type.

    Relative to a generic :class:`BiEncoderConfig`, four options are pinned
    and cannot be changed through ``kwargs``:

    * ``query_pooling_strategy`` is forced to ``None``.
    * ``doc_pooling_strategy`` is forced to ``None``.
    * ``doc_expansion`` is forced to ``False``.
    * ``attend_to_doc_expanded_tokens`` is forced to ``False``.

    NOTE(review): disabling both pooling strategies presumably keeps one
    embedding per token (late-interaction style scoring). Confirm against
    ``BiEncoderModel``.
    """

    model_type = "col"

    def __init__(
        self,
        query_expansion: bool = True,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = "punctuation",
        embedding_dim: int = 128,
        projection: Literal["linear", "linear_no_bias"] | None = "linear_no_bias",
        **kwargs,
    ) -> None:
        """Create a ``ColConfig``.

        :param query_expansion: Forwarded unchanged to :class:`BiEncoderConfig`.
            Defaults to ``True``.
        :param doc_mask_scoring_tokens: Explicit token strings, the literal
            ``"punctuation"``, or ``None``. Forwarded unchanged to
            :class:`BiEncoderConfig`. Presumably these are the document
            tokens excluded from scoring; verify in the base class.
        :param embedding_dim: Forwarded unchanged to :class:`BiEncoderConfig`.
            Defaults to ``128``.
        :param projection: ``"linear"``, ``"linear_no_bias"`` or ``None``.
            Forwarded unchanged to :class:`BiEncoderConfig`. Defaults to
            ``"linear_no_bias"``.
        :param kwargs: Any further :class:`BiEncoderConfig` options. The four
            pinned options listed in the class docstring are overwritten here.
        """
        # Overwrite instead of rejecting. The caller's kwargs may legitimately
        # contain these keys (presumably e.g. when a saved config is reloaded),
        # so the pinned values silently win.
        kwargs.update(
            query_pooling_strategy=None,
            doc_expansion=False,
            attend_to_doc_expanded_tokens=False,
            doc_pooling_strategy=None,
        )
        super().__init__(
            query_expansion=query_expansion,
            doc_mask_scoring_tokens=doc_mask_scoring_tokens,
            embedding_dim=embedding_dim,
            projection=projection,
            **kwargs,
        )
28
29
[docs]
class ColModel(BiEncoderModel):
    """Bi-encoder model for the ``"col"`` model type.

    Adds no behavior of its own beyond binding :class:`ColConfig` as its
    configuration class.
    """

    config_class = ColConfig