Source code for lightning_ir.cross_encoder.tokenizer

  1"""
  2Tokenizer module for cross-encoder models.
  3
  4This module contains the tokenizer class for cross-encoder models.
  5"""
  6
  7from typing import Dict, List, Sequence, Tuple, Type
  8
  9from transformers import BatchEncoding
 10
 11from ..base import LightningIRTokenizer
 12from .config import CrossEncoderConfig
 13
 14
class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
        that the input sequences are of the correct length.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

    def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings.

        Special tokens are not added here because the truncated strings are encoded again, jointly as
        query-document pairs, in :meth:`tokenize`.

        :param text: Texts to truncate
        :type text: Sequence[str]
        :param max_length: Maximum number of tokens to keep per text
        :type max_length: int
        :return: Truncated texts
        :rtype: List[str]
        """
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats queries to match the number of documents.

        :param queries: One entry per query
        :type queries: Sequence[str]
        :param num_docs: Number of documents per query; must have the same length as ``queries``
        :type num_docs: Sequence[int]
        :raises ValueError: If ``queries`` and ``num_docs`` differ in length
        :return: Each query repeated ``num_docs[i]`` times, in order
        :rtype: List[str]
        """
        return [query for query, count in zip(queries, num_docs, strict=True) for _ in range(count)]

    @staticmethod
    def _resolve_num_docs(
        num_queries: int, num_total_docs: int, num_docs: Sequence[int] | int | None
    ) -> List[int]:
        """Converts ``num_docs`` into one document count per query and validates it against the inputs.

        :param num_queries: Number of queries passed to :meth:`tokenize`
        :type num_queries: int
        :param num_total_docs: Number of documents passed to :meth:`tokenize`
        :type num_total_docs: int
        :param num_docs: Per-query counts, a single count shared by all queries, or None to infer an equal split
        :type num_docs: Sequence[int] | int | None
        :raises ValueError: If the counts cannot be inferred or do not match the number of queries/documents
        :return: One document count per query
        :rtype: List[int]
        """
        if num_docs is None:
            if num_queries == 0:
                raise ValueError("At least one query must be provided.")
            if num_total_docs % num_queries != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            return [num_total_docs // num_queries] * num_queries
        if isinstance(num_docs, int):
            # A single integer means every query has the same number of documents.
            counts = [num_docs] * num_queries
        else:
            counts = list(num_docs)
        if len(counts) != num_queries:
            raise ValueError("num_docs must contain exactly one entry per query.")
        if sum(counts) != num_total_docs:
            raise ValueError("The sum of num_docs must equal the number of documents.")
        return counts

    def _preprocess(
        self,
        queries: str | Sequence[str] | None,
        docs: str | Sequence[str] | None,
        num_docs: Sequence[int] | int | None,
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Preprocesses queries and documents to ensure that they are truncated to their respective maximum lengths.

        If both ``queries`` and ``docs`` are strings, they are treated as a single pair and ``num_docs`` is
        ignored. Otherwise each query is repeated once per associated document so that the returned sequences
        align element-wise.

        :param queries: A single query or a sequence of queries
        :type queries: str | Sequence[str] | None
        :param docs: A single document or a sequence of documents
        :type docs: str | Sequence[str] | None
        :param num_docs: Number of documents per query (see :meth:`tokenize`)
        :type num_docs: Sequence[int] | int | None
        :raises ValueError: If an input is missing, the input types differ, or ``num_docs`` is inconsistent
        :return: Truncated (repeated) queries and truncated documents
        :rtype: Tuple[str | Sequence[str], str | Sequence[str]]
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        queries_is_string = isinstance(queries, str)
        docs_is_string = isinstance(docs, str)
        if queries_is_string != docs_is_string:
            raise ValueError("Queries and docs must be both lists or both strings.")
        if queries_is_string:
            # Single query-document pair: truncate and unwrap back to plain strings.
            truncated_query = self._truncate([queries], self.query_length)[0]
            truncated_doc = self._truncate([docs], self.doc_length)[0]
            return truncated_query, truncated_doc
        # Validate the document counts before doing any tokenization work.
        counts = self._resolve_num_docs(len(queries), len(docs), num_docs)
        truncated_queries = self._truncate(queries, self.query_length)
        truncated_docs = self._truncate(docs, self.doc_length)
        return self._repeat_queries(truncated_queries, counts), truncated_docs

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_doc)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :raises ValueError: If the inputs are missing or inconsistent with ``num_docs``
        :return: Tokenized query-document sequence
        :rtype: Dict[str, BatchEncoding]
        """
        repeated_queries, docs = self._preprocess(queries, docs, num_docs)
        if kwargs.get("return_tensors", None) is not None:
            # Tensor outputs are padded to a multiple of 8 by default (presumably for hardware-friendly shapes);
            # an explicit caller-supplied value is respected.
            kwargs.setdefault("pad_to_multiple_of", 8)
        return {"encoding": self(repeated_queries, docs, **kwargs)}