# Source code for lightning_ir.models.t5.tokenizer
from typing import Dict, Literal, Sequence, Type

from transformers import BatchEncoding

from ...cross_encoder.tokenizer import CrossEncoderTokenizer
from .config import T5CrossEncoderConfig

8
class T5CrossEncoderTokenizer(CrossEncoderTokenizer):
    """Cross-encoder tokenizer for T5-style ranking models.

    Renders each query--document pair into a single prompt string whose shape
    depends on ``decoder_strategy`` ("mono" or "rank"), then tokenizes the
    rendered prompts with the underlying tokenizer.
    """

    # Config class used by the lightning-ir loading machinery for this tokenizer.
    config_class: Type[T5CrossEncoderConfig] = T5CrossEncoderConfig

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ):
        """Initialize the tokenizer.

        :param query_length: Maximum query length, defaults to 32.
        :param doc_length: Maximum document length, defaults to 512.
        :param decoder_strategy: Prompt/decoding style, either "mono"
            (monoT5-style, prompt ends with "Relevant:") or "rank"
            (rankT5-style), defaults to "mono".

        Remaining positional and keyword arguments are forwarded to
        :class:`CrossEncoderTokenizer`; ``decoder_strategy`` is forwarded as
        well so it is persisted alongside the other tokenizer settings.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            decoder_strategy=decoder_strategy,
            **kwargs,
        )
        self.decoder_strategy = decoder_strategy

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenize queries and documents into model-ready prompts.

        :param queries: Query or list of queries.
        :param docs: Document or list of documents.
        :param num_docs: Number of documents per query (forwarded to the
            base class's preprocessing to expand queries accordingly).
        :raises ValueError: If ``self.decoder_strategy`` is neither "mono"
            nor "rank".
        :return: Mapping with a single "encoding" entry holding the
            tokenized prompts.
        """
        expanded, doc_texts = self._preprocess(queries, docs, num_docs)

        # Prompt template per decoder strategy; "mono" appends the
        # "Relevant:" cue that monoT5-style models score against.
        templates = {
            "mono": "Query: {query} Document: {doc} Relevant:",
            "rank": "Query: {query} Document: {doc}",
        }
        if self.decoder_strategy not in templates:
            raise ValueError(f"Unknown decoder strategy: {self.decoder_strategy}")
        template = templates[self.decoder_strategy]

        prompts = [template.format(query=q, doc=d) for q, d in zip(expanded, doc_texts)]

        # When tensors are requested, pad to a multiple of 8 —
        # presumably for tensor-core-friendly shapes; TODO confirm.
        if kwargs.get("return_tensors") is not None:
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(prompts, **kwargs)}