1"""
2Tokenizer module for bi-encoder models.
3
This module contains the tokenizer class for bi-encoder models.
5"""
6
7import warnings
8from typing import Dict, Sequence, Type
9
10from tokenizers.processors import TemplateProcessing
11from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast
12
13from ..base import LightningIRTokenizer
14from .config import BiEncoderConfig
15
16
class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_expansion: bool = False,
        query_length: int = 32,
        attend_to_query_expanded_tokens: bool = False,
        doc_expansion: bool = False,
        doc_length: int = 512,
        attend_to_doc_expanded_tokens: bool = False,
        add_marker_tokens: bool = True,
        **kwargs,
    ):
39 """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately. Optionally
40 adds marker tokens are added to encoded input sequences.
41
42 :param query_expansion: Whether to expand queries with mask tokens, defaults to False
43 :type query_expansion: bool, optional
44 :param query_length: Maximum query length in number of tokens, defaults to 32
45 :type query_length: int, optional
46 :param attend_to_query_expanded_tokens: Whether to let non-expanded query tokens be able to attend to mask
47 expanded query tokens, defaults to False
48 :type attend_to_query_expanded_tokens: bool, optional
49 :param doc_expansion: Whether to expand documents with mask tokens, defaults to False
50 :type doc_expansion: bool, optional
51 :param doc_length: Maximum document length in number of tokens, defaults to 512
52 :type doc_length: int, optional
53 :param attend_to_doc_expanded_tokens: Whether to let non-expanded document tokens be able to attend to
54 mask expanded document tokens, defaults to False
55 :type attend_to_doc_expanded_tokens: bool, optional
56 :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences,
57 defaults to True
58 :type add_marker_tokens: bool, optional
59 :raises ValueError: If add_marker_tokens is True and a non-supported tokenizer is used
60 """
        super().__init__(
            *args,
            query_expansion=query_expansion,
            query_length=query_length,
            attend_to_query_expanded_tokens=attend_to_query_expanded_tokens,
            doc_expansion=doc_expansion,
            doc_length=doc_length,
            attend_to_doc_expanded_tokens=attend_to_doc_expanded_tokens,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_expansion = query_expansion
        self.query_length = query_length
        self.attend_to_query_expanded_tokens = attend_to_query_expanded_tokens
        self.doc_expansion = doc_expansion
        self.doc_length = doc_length
        self.attend_to_doc_expanded_tokens = attend_to_doc_expanded_tokens
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                raise ValueError("Adding marker tokens is only supported for BertTokenizer.")
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.QUERY_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.DOC_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
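            # Illustrative note: with these post-processors, a single query is encoded as
            # "[CLS] [QUE] <query tokens> [SEP]", a single document as "[CLS] [DOC] <doc tokens> [SEP]",
            # and a query-document pair as "[CLS] [QUE] <query> [SEP] [DOC] <doc> [SEP]".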

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        :return: Token id of the query token
        :rtype: int | None
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        :return: Token id of the document token
        :rtype: int | None
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PreTrainedTokenizer.__call__`_ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` instead.

        .. _PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        :param text: Text to tokenize
        :type text: str | Sequence[str]
        :param warn: Set to False to silence the warning, defaults to True
        :type warn: bool, optional
        :return: Tokenized text
        :rtype: BatchEncoding
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use tokenize_query and tokenize_doc to make sure "
                "marker tokens and query/doc expansion are applied."
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            # pad to a multiple of 8 when returning tensors (typically more efficient on GPUs)
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        :param queries: Query or queries to tokenize
        :type queries: Sequence[str] | str
        :return: Tokenized queries
        :rtype: BatchEncoding
        """
        kwargs["max_length"] = self.query_length
        if self.query_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(queries, *args, post_processor=self.query_post_processor, **kwargs)
        if self.query_expansion:
            self._expand(encoding, self.attend_to_query_expanded_tokens)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        :param docs: Document or documents to tokenize
        :type docs: Sequence[str] | str
        :return: Tokenized documents
        :rtype: BatchEncoding
        """
        kwargs["max_length"] = self.doc_length
        if self.doc_expansion:
            kwargs["padding"] = "max_length"
        else:
            kwargs["truncation"] = True
        encoding = self._encode(docs, *args, post_processor=self.doc_post_processor, **kwargs)
        if self.doc_expansion:
            self._expand(encoding, self.attend_to_doc_expanded_tokens)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
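

# Usage sketch (illustrative only; assumes a BERT checkpoint such as "bert-base-uncased" and the
# ``from_pretrained`` constructor inherited from the Hugging Face tokenizer base classes):
#
#   tokenizer = BiEncoderTokenizer.from_pretrained("bert-base-uncased", query_length=32, doc_length=512)
#   query_encoding = tokenizer.tokenize_query("what is information retrieval?", return_tensors="pt")
#   doc_encoding = tokenizer.tokenize_doc(["Information retrieval finds relevant documents."], return_tensors="pt")
#   encodings = tokenizer.tokenize(
#       queries="what is information retrieval?",
#       docs=["Information retrieval finds relevant documents."],
#       return_tensors="pt",
#   )
#   # ``encodings`` is a dict with the keys "query_encoding" and "doc_encoding".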