Source code for lightning_ir.bi_encoder.config

  1"""
  2Configuration module for bi-encoder models.
  3
  4This module defines the configuration class used to instantiate bi-encoder models.
  5"""
  6
  7import json
  8import os
  9from os import PathLike
 10from typing import Any, Dict, Literal, Sequence, Tuple
 11
 12from ..base import LightningIRConfig
 13
 14
class BiEncoderConfig(LightningIRConfig):
    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    TOKENIZER_ARGS = LightningIRConfig.TOKENIZER_ARGS.union(
        {
            "query_expansion",
            "attend_to_query_expanded_tokens",
            "doc_expansion",
            "attend_to_doc_expanded_tokens",
            "add_marker_tokens",
        }
    )
    """Arguments for the tokenizer."""

    ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union(
        {
            "similarity_function",
            "query_pooling_strategy",
            "query_mask_scoring_tokens",
            "query_aggregation_function",
            "doc_pooling_strategy",
            "doc_mask_scoring_tokens",
            "normalize",
            "sparsification",
            "embedding_dim",
            "projection",
        }
    ).union(TOKENIZER_ARGS)
    """Arguments added to the configuration."""
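    # A hedged sketch of how these sets can be used (the routing itself happens in the
    # LightningIRConfig base class and is an assumption here, not shown in this module):
    # tokenizer-related settings can be collected from a config instance by name, e.g.
    #
    #     tokenizer_kwargs = {arg: getattr(config, arg) for arg in BiEncoderConfig.TOKENIZER_ARGS}
    #
    # assuming every name in TOKENIZER_ARGS is set as an attribute on the instance.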
    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        query_expansion: bool = False,
        attend_to_query_expanded_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_expansion: bool = False,
        attend_to_doc_expanded_tokens: bool = False,
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] | None = "mean",
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        embedding_dim: int = 768,
        projection: Literal["linear", "linear_no_bias", "mlm"] | None = "linear",
        **kwargs,
    ):
        """Configuration class for a bi-encoder model.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param attend_to_query_expanded_tokens: Whether to allow query tokens to attend to mask tokens,
            defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "mean"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param attend_to_doc_expanded_tokens: Whether to allow document tokens to attend to mask tokens,
            defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "mean"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to add extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param embedding_dim: The output embedding dimension, defaults to 768
        :type embedding_dim: int, optional
        :param projection: Whether and how to project the output embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias', 'mlm'] | None, optional
        """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.query_expansion = query_expansion
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.query_pooling_strategy = query_pooling_strategy
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_expansion = doc_expansion
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.doc_pooling_strategy = doc_pooling_strategy
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim = embedding_dim
        self.projection = projection
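    # A minimal usage sketch (hypothetical values, not part of this module): a loosely
    # ColBERT-style configuration with query expansion, marker tokens, and a low-dimensional
    # projection.
    #
    #     config = BiEncoderConfig(
    #         query_length=32,
    #         doc_length=184,
    #         similarity_function="cosine",
    #         query_expansion=True,
    #         add_marker_tokens=True,
    #         embedding_dim=128,
    #         projection="linear_no_bias",
    #     )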
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type, and to remove the mask scoring tokens.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """

        output = super().to_dict()
        if "query_mask_scoring_tokens" in output:
            output.pop("query_mask_scoring_tokens")
        if "doc_mask_scoring_tokens" in output:
            output.pop("doc_mask_scoring_tokens")
        return output
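    # Hedged sketch of the effect of to_dict: the serialized dictionary drops the mask
    # scoring tokens, which are instead persisted separately by save_pretrained below.
    #
    #     config = BiEncoderConfig(query_mask_scoring_tokens="punctuation")
    #     assert "query_mask_scoring_tokens" not in config.to_dict()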
    def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
        """Overrides the transformers.PretrainedConfig.save_pretrained_ method to additionally save the tokens which
        should be masked during scoring.

        .. _transformers.PretrainedConfig.save_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.save_pretrained

        :param save_directory: Directory to save the configuration
        :type save_directory: str | PathLike
        """
        super().save_pretrained(save_directory, **kwargs)
        with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
            json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
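    # Hedged sketch ("path/to/config" is a placeholder): saving writes the usual
    # config.json plus a mask_scoring_tokens.json holding the query and doc token lists.
    #
    #     config.save_pretrained("path/to/config")
    #     # path/to/config/config.json
    #     # path/to/config/mask_scoring_tokens.json -> {"query": ..., "doc": ...}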
    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str | PathLike, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Overrides the transformers.PretrainedConfig.get_config_dict_ method to load the tokens that should be
        masked during scoring.

        .. _transformers.PretrainedConfig.get_config_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.get_config_dict

        :param pretrained_model_name_or_path: Name or path of the pretrained model
        :type pretrained_model_name_or_path: str | PathLike
        :return: Configuration dictionary and additional keyword arguments
        :rtype: Tuple[Dict[str, Any], Dict[str, Any]]
        """
        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
        mask_scoring_tokens = None
        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
        if os.path.exists(mask_scoring_tokens_path):
            with open(mask_scoring_tokens_path) as f:
                mask_scoring_tokens = json.load(f)
            config_dict["query_mask_scoring_tokens"] = mask_scoring_tokens["query"]
            config_dict["doc_mask_scoring_tokens"] = mask_scoring_tokens["doc"]
        return config_dict, kwargs
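# Round-trip sketch (directory name is a placeholder): loading from a local directory
# restores the mask scoring tokens from mask_scoring_tokens.json alongside the standard
# configuration values; for paths without that file, the keys are simply absent.
#
#     config_dict, kwargs = BiEncoderConfig.get_config_dict("path/to/config")
#     config_dict["query_mask_scoring_tokens"]  # tokens written by save_pretrained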