Source code for lightning_ir.cross_encoder.module

 1"""
 2Lightning module for cross-encoder models.
 3
 4This module defines the Lightning IR module class used to implement cross-encoder models.
 5"""
 6
 7from typing import List, Sequence, Tuple
 8
 9import torch
10
11from ..base.module import LightningIRModule
12from ..data import RankBatch, SearchBatch, TrainBatch
13from ..loss.loss import LossFunction, ScoringLossFunction
14from .config import CrossEncoderConfig
15from .model import CrossEncoderModel, CrossEncoderOutput
16from .tokenizer import CrossEncoderTokenizer
17
18
class CrossEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """:class:`.LightningIRModule` for cross-encoder models. It contains a :class:`.CrossEncoderModel` and a
        :class:`.CrossEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: CrossEncoderConfig to apply when loading from backbone model, defaults to None
        :type config: CrossEncoderConfig | None, optional
        :param model: Already instantiated CrossEncoderModel, defaults to None
        :type model: CrossEncoderModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        # Annotation-only declarations: they narrow the attribute types for type checkers and assign nothing.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the
        backbone model as well as the relevance scores.

        :param batch: Batch of data to run the forward pass on
        :type batch: RankBatch | TrainBatch | SearchBatch
        :raises ValueError: If the batch is a SearchBatch
        :return: Output of the model
        :rtype: CrossEncoderOutput
        """
        if isinstance(batch, SearchBatch):
            raise ValueError("Searching is not available for cross-encoders")
        queries = batch.queries
        # Flatten the per-query document groups into one list and remember each group's size so that
        # ``prepare_input`` receives both the flat documents and the per-query counts.
        flat_docs = [doc for query_docs in batch.docs for doc in query_docs]
        num_docs = [len(query_docs) for query_docs in batch.docs]
        encoding = self.prepare_input(queries, flat_docs, num_docs)
        output = self.model.forward(encoding["encoding"])
        return output

    def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch.

        The scores in ``output`` are reshaped in place to ``(len(batch.query_ids), -1)`` and ``batch.targets`` is
        reshaped in place to the shape of the scores plus one trailing dimension before the loss functions are
        applied.

        :param batch: Training batch providing ``query_ids`` and ``targets``
        :type batch: TrainBatch
        :param output: Model output for ``batch``; a forward pass is only run here when ``None`` is passed
        :type output: CrossEncoderOutput
        :raises ValueError: If no loss functions are set in the module
        :raises ValueError: If the output has no scores or the batch has no targets
        :raises RuntimeError: If a configured loss function is not a :class:`.ScoringLossFunction`
        :return: One loss tensor per configured loss function, in configuration order
        :rtype: List[torch.Tensor]
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")
        # Reuse the output handed in by the caller rather than unconditionally re-running the backbone; a second
        # forward pass doubles the compute and, with dropout active, yields different scores than the caller saw.
        if output is None:
            output = self.forward(batch)
        if output.scores is None or batch.targets is None:
            raise ValueError("scores and targets must be set in the output and batch")

        # One row of scores per query; targets follow the same layout with an extra trailing dimension.
        output.scores = output.scores.view(len(batch.query_ids), -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)

        losses = []
        # Each entry is unpacked as a pair; the second element is not used in this method.
        for loss_function, _ in self.loss_functions:
            if not isinstance(loss_function, ScoringLossFunction):
                raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
            losses.append(loss_function.compute_loss(output, batch))
        return losses