1"""
2Tokenizer module for cross-encoder models.
3
4This module contains the tokenizer class cross-encoder models.
5"""

from typing import Dict, List, Sequence, Tuple, Type

from transformers import BatchEncoding

from ..base import LightningIRTokenizer
from .config import CrossEncoderConfig


class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and
        ensures that the input sequences are of the correct length.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

    def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens, and decodes them back to
        strings."""
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )
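
    # Note: truncation here round-trips text through encode and decode, so the truncated
    # strings may differ slightly from the original surface form (e.g., whitespace or
    # casing), depending on the underlying Hugging Face tokenizer.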

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats each query to match the number of documents associated with it."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]
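
    # For example (illustrative values): queries=["q1", "q2"] with num_docs=[2, 1]
    # yields ["q1", "q1", "q2"], so each repeated query lines up with exactly one
    # document when the two lists are zipped by the tokenizer.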

    def _preprocess(
        self,
        queries: str | Sequence[str] | None,
        docs: str | Sequence[str] | None,
        num_docs: Sequence[int] | int | None,
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Preprocesses queries and documents to ensure that they are truncated to their respective maximum
        lengths."""
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        queries_is_string = isinstance(queries, str)
        docs_is_string = isinstance(docs, str)
        if queries_is_string != docs_is_string:
            raise ValueError("Queries and docs must be both lists or both strings.")
        if queries_is_string and docs_is_string:
            queries = [queries]
            docs = [docs]
        truncated_queries = self._truncate(queries, self.query_length)
        truncated_docs = self._truncate(docs, self.doc_length)
        if not queries_is_string:
            if num_docs is None:
                # Infer an equal split of documents across queries.
                if len(docs) % len(queries) != 0:
                    raise ValueError("Number of documents must be divisible by the number of queries.")
                num_docs = [len(docs) // len(queries)] * len(queries)
            elif isinstance(num_docs, int):
                # A single integer means the same number of documents for every query.
                num_docs = [num_docs] * len(queries)
            repeated_queries = self._repeat_queries(truncated_queries, num_docs)
            docs = truncated_docs
        else:
            repeated_queries = truncated_queries[0]
            docs = truncated_docs[0]
        return repeated_queries, docs
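
    # Worked example (illustrative values): given 2 queries and 6 documents,
    # num_docs=None infers 6 // 2 = 3 documents per query, num_docs=3 states the
    # same explicitly, and num_docs=[4, 2] pairs the first query with 4 documents
    # and the second with 2.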
    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers,
            `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the number of
            documents, i.e., the sequence contains one value per query specifying the number of documents for
            that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
            the number of documents per query by dividing the number of documents by the number of queries,
            defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Tokenized query-document sequence
        :rtype: Dict[str, BatchEncoding]
        """
        repeated_queries, docs = self._preprocess(queries, docs, num_docs)
        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            # Padding to a multiple of 8 tends to be more efficient on tensor cores.
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(repeated_queries, docs, **kwargs)}
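
# Usage sketch (illustrative; the checkpoint name is a placeholder and loading follows the
# Hugging Face ``from_pretrained`` convention that :class:`.LightningIRTokenizer` builds on):
#
#     tokenizer = CrossEncoderTokenizer.from_pretrained(
#         "bert-base-uncased", query_length=32, doc_length=512
#     )
#     encoding = tokenizer.tokenize(
#         queries=["what is information retrieval?"],
#         docs=["IR finds relevant documents for a query.", "An unrelated passage."],
#         num_docs=[2],
#         return_tensors="pt",
#         padding=True,
#     )["encoding"]
#     # encoding.input_ids has shape (2, seq_len): one row per query-document pair.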