1"""
2Configuration module for bi-encoder models.
3
4This module defines the configuration class used to instantiate bi-encoder models.
5"""
6
7import json
8import os
9from os import PathLike
10from typing import Any, Dict, Literal, Sequence, Tuple
11
12from ..base import LightningIRConfig
13
14
class BiEncoderConfig(LightningIRConfig):
    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
        {
            "query_expansion",
            "attend_to_query_expanded_tokens",
            "doc_expansion",
            "attend_to_doc_expanded_tokens",
            "add_marker_tokens",
        }
    )
    """Arguments for the tokenizer."""

    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
        {
            "similarity_function",
            "query_pooling_strategy",
            "query_mask_scoring_tokens",
            "query_aggregation_function",
            "doc_pooling_strategy",
            "doc_mask_scoring_tokens",
            "normalize",
            "sparsification",
            "embedding_dim",
            "projection",
        }
    ).union(TOKENIZER_ARGS)
    """Arguments added to the configuration."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
        **kwargs,
    ):
67 """Configuration class for a bi-encoder model.
68
69 :param query_length: Maximum query length, defaults to 32
70 :type query_length: int, optional
71 :param doc_length: Maximum document length, defaults to 512
72 :type doc_length: int, optional
73 :param similarity_function: Similarity function to compute scores between query and document embeddings,
74 defaults to "dot"
75 :type similarity_function: Literal['cosine', 'dot'], optional
76 :param query_expansion: Whether to expand queries with mask tokens, defaults to False
77 :type query_expansion: bool, optional
78 :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
79 defaults to False
80 :type attend_to_query_expanded_tokens: bool, optional
81 :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "mean"
82 :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
83 :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
84 :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
85 :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
86 :type query_aggregation_function: Literal[ 'sum', 'mean', 'max', 'harmonic_mean' ], optional
87 :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
88 :type doc_expansion: bool, optional
89 :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
90 defaults to False
91 :type attend_to_doc_expanded_tokens: bool, optional
92 :param doc_pooling_strategy: Whether andhow to pool document token embeddings, defaults to "mean"
93 :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
94 :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
95 :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
96 :param normalize: Whether to normalize query and document embeddings, defaults to False
97 :type normalize: bool, optional
98 :param sparsification: Whether and which sparsification function to apply, defaults to None
99 :type sparsification: Literal['relu', 'relu_log'] | None, optional
100 :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents, defaults to False
101 :type add_marker_tokens: bool, optional
102 :param embedding_dim: The output embedding dimension, defaults to 768
103 :type embedding_dim: int, optional
104 :param projection: Whether and how to project the output emeddings, defaults to "linear"
105 :type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
106 """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.query_pooling_strategy = query_pooling_strategy
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.doc_pooling_strategy = doc_pooling_strategy
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim = embedding_dim
        self.projection = projection

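    # Example (minimal sketch): a late-interaction style configuration keeping one embedding per
    # token. The specific values below are illustrative assumptions, not recommended defaults.
    #
    #     config = BiEncoderConfig(
    #         query_expansion=True,
    #         query_pooling_strategy=None,  # keep per-token query embeddings
    #         doc_pooling_strategy=None,    # keep per-token document embeddings
    #         add_marker_tokens=True,       # prepend [Q] / [D] marker tokens
    #         embedding_dim=128,
    #         projection="linear_no_bias",
    #         similarity_function="dot",
    #         query_aggregation_function="sum",
    #     )
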
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type, and to remove the mask scoring tokens.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """

        output = super().to_dict()
        if "query_mask_scoring_tokens" in output:
            output.pop("query_mask_scoring_tokens")
        if "doc_mask_scoring_tokens" in output:
            output.pop("doc_mask_scoring_tokens")
        return output

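    # Example (sketch): the serialized dictionary drops the mask scoring tokens so they do not end
    # up in config.json; they are persisted separately by ``save_pretrained`` below.
    #
    #     config = BiEncoderConfig(query_mask_scoring_tokens="punctuation")
    #     assert "query_mask_scoring_tokens" not in config.to_dict()
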
    def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
        """Overrides the transformers.PretrainedConfig.save_pretrained_ method to additionally save the tokens which
        should be masked during scoring.

        .. _transformers.PretrainedConfig.save_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.save_pretrained

        :param save_directory: Directory to save the configuration
        :type save_directory: str | PathLike
        """
        super().save_pretrained(save_directory, **kwargs)
        with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
            json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)

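    # Example (sketch): saving writes the usual config.json plus a mask_scoring_tokens.json file
    # holding the query and document mask scoring tokens. The directory name is illustrative.
    #
    #     config = BiEncoderConfig(doc_mask_scoring_tokens="punctuation")
    #     config.save_pretrained("./bi-encoder-config")
    #     # ./bi-encoder-config/config.json
    #     # ./bi-encoder-config/mask_scoring_tokens.json
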
    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str | PathLike, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Overrides the transformers.PretrainedConfig.get_config_dict_ method to load the tokens that should be
        masked during scoring.

        .. _transformers.PretrainedConfig.get_config_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.get_config_dict

        :param pretrained_model_name_or_path: Name or path of the pretrained model
        :type pretrained_model_name_or_path: str | PathLike
        :return: Configuration dictionary and additional keyword arguments
        :rtype: Tuple[Dict[str, Any], Dict[str, Any]]
        """
        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
        mask_scoring_tokens = None
        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
        if os.path.exists(mask_scoring_tokens_path):
            with open(mask_scoring_tokens_path) as f:
                mask_scoring_tokens = json.load(f)
            config_dict["query_mask_scoring_tokens"] = mask_scoring_tokens["query"]
            config_dict["doc_mask_scoring_tokens"] = mask_scoring_tokens["doc"]
        return config_dict, kwargs
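

# Example (sketch): loading reverses ``save_pretrained``. ``from_pretrained`` is the standard
# transformers entry point and is assumed to route through ``get_config_dict`` above, so the mask
# scoring tokens are restored from mask_scoring_tokens.json when that file is present.
#
#     config = BiEncoderConfig.from_pretrained("./bi-encoder-config")
#     config.query_mask_scoring_tokens, config.doc_mask_scoring_tokens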