Source code for lightning_ir.bi_encoder.tokenizer

  1"""
  2Tokenizer module for bi-encoder models.
  3
  4This module contains the tokenizer class bi-encoder models.
  5"""

import warnings
from typing import Dict, Sequence, Type

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast

from ..base import LightningIRTokenizer
from .config import BiEncoderConfig


class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_expansion: bool = False,
        query_length: int = 32,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        doc_length: int = 512,
        attend_to_doc_expanded_tokens: bool = False,
        add_marker_tokens: bool = True,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately and
        optionally adds marker tokens to the encoded input sequences.

        :param query_expansion: Whether to expand queries with mask tokens, defaults to False
        :type query_expansion: bool, optional
        :param query_length: Maximum query length in number of tokens, defaults to 32
        :type query_length: int, optional
        :param attend_to_query_expanded_tokens: Whether to allow non-expanded query tokens to attend to
            mask-expanded query tokens, defaults to False
        :type attend_to_query_expanded_tokens: bool, optional
        :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
        :type doc_expansion: bool, optional
        :param doc_length: Maximum document length in number of tokens, defaults to 512
        :type doc_length: int, optional
        :param attend_to_doc_expanded_tokens: Whether to allow non-expanded document tokens to attend to
            mask-expanded document tokens, defaults to False
        :type attend_to_doc_expanded_tokens: bool, optional
        :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences,
            defaults to True
        :type add_marker_tokens: bool, optional
        :raises ValueError: If add_marker_tokens is True and a non-supported tokenizer is used
        """
        super().__init__(
            *args,
            query_expansion=query_expansion,
            query_length=query_length,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            doc_length=doc_length,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.query_length = query_length
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.doc_length = doc_length
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                raise ValueError("Adding marker tokens is only supported for BertTokenizer.")
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.QUERY_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.DOC_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
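    # Illustrative note (added to this listing, not part of the upstream source): with
    # add_marker_tokens=True the post-processors above insert the marker token directly
    # after [CLS]. Assuming a BERT vocabulary, a query such as "what is ir" is encoded as
    #     [CLS] [QUE] what is ir [SEP]
    # and a document is encoded as
    #     [CLS] [DOC] <document tokens> [SEP]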

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        :return: Token id of the query token
        :rtype: int | None
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        :return: Token id of the document token
        :rtype: int | None
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the PreTrainedTokenizer.__call__ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` instead.

        .. PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        :param text: Text to tokenize
        :type text: str | Sequence[str]
        :param warn: Set to False to silence the warning, defaults to True
        :type warn: bool, optional
        :return: Tokenized text
        :rtype: BatchEncoding
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use tokenize_query and tokenize_doc to make sure "
                "marker tokens and query/doc expansion are applied."
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        :param queries: Query or queries to tokenize
        :type queries: Sequence[str] | str
        :return: Tokenized queries
        :rtype: BatchEncoding
        """
        kwargs["max_length"] = self.query_length
        if self.query_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(queries, *args, post_processor=self.query_post_processor, **kwargs)
        if self.query_expansion:
            self._expand(encoding, self.attend_to_query_expanded_tokens)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        :param docs: Document or documents to tokenize
        :type docs: Sequence[str] | str
        :return: Tokenized documents
        :rtype: BatchEncoding
        """
        kwargs["max_length"] = self.doc_length
        if self.doc_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(docs, *args, post_processor=self.doc_post_processor, **kwargs)
        if self.doc_expansion:
            self._expand(encoding, self.attend_to_doc_expanded_tokens)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
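
A minimal usage sketch, not part of the module above: it assumes BiEncoderTokenizer is exported from the top-level lightning_ir package, that a BERT checkpoint such as bert-base-uncased serves as the backbone, and that extra keyword arguments to from_pretrained are forwarded to __init__ as usual for Hugging Face tokenizers. The checkpoint name and texts are illustrative.

from lightning_ir import BiEncoderTokenizer  # assumed top-level export

# Instantiate from a BERT backbone; the keyword arguments mirror the defaults
# documented in __init__ (add_marker_tokens=True adds [QUE] and [DOC]).
tokenizer = BiEncoderTokenizer.from_pretrained(
    "bert-base-uncased",
    query_length=32,
    doc_length=512,
    add_marker_tokens=True,
)

# Queries and documents are tokenized separately so that the query/document
# post-processors and length limits are applied.
query_encoding = tokenizer.tokenize_query("what is information retrieval")
doc_encoding = tokenizer.tokenize_doc(
    ["Information retrieval is the task of finding relevant documents."]
)

# tokenize() wraps both calls and returns a dictionary with the keys
# "query_encoding" and "doc_encoding".
encodings = tokenizer.tokenize(
    queries="what is information retrieval",
    docs=["Information retrieval is the task of finding relevant documents."],
)
print(list(encodings))  # ['query_encoding', 'doc_encoding']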