Source code for lightning_ir.bi_encoder.module

  1"""
  2Module module for bi-encoder models.
  3
  4This module defines the Lightning IR module class used to implement bi-encoder models.
  5"""
  6
  7from pathlib import Path
  8from typing import List, Sequence, Tuple
  9
 10import torch
 11from transformers import BatchEncoding
 12
 13from ..base import LightningIRModule, LightningIROutput
 14from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
 15from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
 16from ..retrieve import SearchConfig, Searcher
 17from .config import BiEncoderConfig
 18from .model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
 19from .tokenizer import BiEncoderTokenizer
 20
 21
[docs] 22class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
    ):
        """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
        :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: BiEncoderConfig to apply when loading from backbone model, defaults to None
        :type config: BiEncoderConfig | None, optional
        :param model: Already instantiated BiEncoderModel, defaults to None
        :type model: BiEncoderModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning; optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation
            or testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param index_dir: Path to an index used for retrieval, defaults to None
        :type index_dir: Path | None, optional
        :param search_config: Configuration to use during retrieval, defaults to None
        :type search_config: SearchConfig | None, optional
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        self.scoring_function = self.model.scoring_function
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            # Resize the embedding matrix to fit the added marker tokens (padded to a multiple of 8).
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir
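A minimal construction sketch, assuming the package's top-level exports and a hypothetical backbone checkpoint name:

from lightning_ir import BiEncoderConfig, BiEncoderModule

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",  # hypothetical backbone checkpoint
    config=BiEncoderConfig(),
    evaluation_metrics=["nDCG@10"],  # ir-measures measure string
)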
    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        :return: Searcher instance
        :rtype: Searcher | None
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)
    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()
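A retrieval wiring sketch. `FaissSearchConfig` and its `k` parameter are assumed here for illustration; any `SearchConfig` subclass shipped with the library applies, and the paths are placeholders:

from pathlib import Path

from lightning_ir import BiEncoderModule
from lightning_ir.retrieve import FaissSearchConfig  # assumed SearchConfig subclass

module = BiEncoderModule(
    model_name_or_path="path/to/fine-tuned-model",  # placeholder path
    index_dir=Path("path/to/index"),                # placeholder index location
    search_config=FaissSearchConfig(k=10),          # assumed: retrieve top-10 docs per query
)
# During testing, on_test_start builds the searcher from index_dir and search_config.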
    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Runs a forward pass of the model on a batch of data. The output will vary depending on the type of batch.
        If the batch is a :class:`.RankBatch`, query and document embeddings are computed and the relevance score is
        computed using the :attr:`.scoring_function`. If the batch is an :class:`.IndexBatch`, only document
        embeddings are computed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
        the model will additionally retrieve documents if :attr:`.searcher` is set.

        :param batch: Input batch containing queries and/or documents
        :type batch: RankBatch | IndexBatch | SearchBatch
        :raises ValueError: If the input batch contains neither queries nor documents
        :return: Output of the model
        :rtype: BiEncoderOutput
        """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            scores, doc_ids, num_docs = self.searcher.search(output)
            output.scores = scores
            # Cumulative offsets split the flat list of retrieved doc ids back into per-query tuples.
            cum_num_docs = [0] + [sum(num_docs[: i + 1]) for i in range(len(num_docs))]
            doc_ids = tuple(tuple(doc_ids[cum_num_docs[i] : cum_num_docs[i + 1]]) for i in range(len(num_docs)))
            batch.doc_ids = doc_ids
        return output
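The flattening and offset arithmetic above can be traced with toy values; everything here follows directly from the method body:

docs = [["d1", "d2"], ["d3"]]                       # a RankBatch nests docs per query
num_docs = [len(d) for d in docs]                   # [2, 1]
flat_docs = [d for nested in docs for d in nested]  # ["d1", "d2", "d3"]

# The SearchBatch branch inverts the flattening with cumulative offsets:
cum_num_docs = [0] + [sum(num_docs[: i + 1]) for i in range(len(num_docs))]  # [0, 2, 3]
doc_ids = tuple(tuple(flat_docs[cum_num_docs[i] : cum_num_docs[i + 1]]) for i in range(len(num_docs)))
assert doc_ids == (("d1", "d2"), ("d3",))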
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Computes relevance scores for queries and documents.

        :param queries: Queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: BiEncoderOutput
        """
        return super().score(queries, docs)
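A scoring sketch with toy strings, reusing the `module` from the construction example above:

output = module.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "Berlin is the capital of Germany."],
)
print(output.scores)  # one relevance score per query-document pair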
    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # Log the average number of non-zero entries per query and per document embedding.
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        ib_scoring_mask = torch.cat(
            [
                embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
            ],
            dim=1,
        ).view(-1, num_embs)
        ib_encoding = {}
        for key, value in embeddings.encoding.items():
            seq_len = value.shape[-1]
            ib_encoding[key] = torch.cat(
                [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                dim=1,
            ).view(-1, seq_len)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, BatchEncoding(ib_encoding))
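A shape sketch of the in-batch gather performed by `_get_ib_doc_embeddings`, using toy tensors with one positive and one in-batch negative per query:

import torch

num_queries, num_embs, emb_dim = 2, 4, 8
doc_embeddings = torch.randn(2, num_embs, emb_dim)  # one positive document per query
pos_idcs = torch.tensor([0, 1])                     # each query's own document
neg_idcs = torch.tensor([1, 0])                     # the other query's document as negative

ib_embeddings = torch.cat(
    [
        doc_embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
        doc_embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
    ],
    dim=1,
).view(-1, num_embs, emb_dim)
print(ib_embeddings.shape)  # torch.Size([4, 4, 8]) == (num_queries * 2, num_embs, emb_dim)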
    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | IndexBatch | SearchBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: BiEncoderOutput
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")