Source code for lightning_ir.cross_encoder.module
1"""
2Module module for cross-encoder models.
3
4This module defines the Lightning IR module class used to implement cross-encoder models.
5"""
6
7from typing import List, Sequence, Tuple
8
9import torch
10
11from ..base.module import LightningIRModule
12from ..data import RankBatch, SearchBatch, TrainBatch
13from ..loss.loss import LossFunction, ScoringLossFunction
14from .config import CrossEncoderConfig
15from .model import CrossEncoderModel, CrossEncoderOutput
16from .tokenizer import CrossEncoderTokenizer
17
18
[docs]
class CrossEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: CrossEncoderConfig | None = None,
        model: CrossEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """:class:`.LightningIRModule` for cross-encoder models. It contains a :class:`.CrossEncoderModel` and a
        :class:`.CrossEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: CrossEncoderConfig to apply when loading from backbone model, defaults to None
        :type config: CrossEncoderConfig | None, optional
        :param model: Already instantiated CrossEncoderModel, defaults to None
        :type model: CrossEncoderModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics)
        # Annotation-only statements: they assign nothing and only narrow the attribute types
        # for static type checkers to the cross-encoder specific classes.
        self.model: CrossEncoderModel
        self.config: CrossEncoderConfig
        self.tokenizer: CrossEncoderTokenizer

    def forward(self, batch: RankBatch | TrainBatch | SearchBatch) -> CrossEncoderOutput:
        """Runs a forward pass of the model on a batch of data and returns the contextualized embeddings from the
        backbone model as well as the relevance scores.

        :param batch: Batch of data to run the forward pass on
        :type batch: RankBatch | TrainBatch | SearchBatch
        :raises ValueError: If the batch is a SearchBatch
        :return: Output of the model
        :rtype: CrossEncoderOutput
        """
        if isinstance(batch, SearchBatch):
            raise ValueError("Searching is not available for cross-encoders")
        queries = batch.queries
        # ``batch.docs`` holds one list of documents per query. Flatten it into a single list and keep the
        # per-query counts so ``prepare_input`` can match each document back to its query.
        docs = [doc for query_docs in batch.docs for doc in query_docs]
        num_docs = [len(query_docs) for query_docs in batch.docs]
        encoding = self.prepare_input(queries, docs, num_docs)
        output = self.model.forward(encoding["encoding"])
        return output

    def _compute_losses(self, batch: TrainBatch, output: CrossEncoderOutput) -> List[torch.Tensor]:
        """Computes one loss per configured loss function for a training batch.

        The ``output`` passed by the caller is used as-is instead of running the model a second time. The caller
        is expected to pass the result of :meth:`forward` on ``batch``. Recomputing it would double the cost of
        every training step. With stochastic layers such as dropout active, it would also yield scores that
        differ from the ones the caller holds.

        Side effects: ``output.scores`` is reshaped in place to ``(len(batch.query_ids), -1)``, and
        ``batch.targets`` is reshaped in place to ``(*output.scores.shape, -1)``.

        :param batch: Training batch; must provide ``query_ids`` and ``targets``
        :type batch: TrainBatch
        :param output: Output of the forward pass on ``batch``; must provide ``scores``
        :type output: CrossEncoderOutput
        :raises ValueError: If no loss functions are set, or if ``output.scores`` or ``batch.targets`` is None
        :raises RuntimeError: If a configured loss function is not a :class:`.ScoringLossFunction`
        :return: Unweighted loss value of each loss function, in the order of ``self.loss_functions``
        :rtype: List[torch.Tensor]
        """
        if self.loss_functions is None:
            raise ValueError("loss_functions must be set in the module")
        if output.scores is None or batch.targets is None:
            raise ValueError("scores and targets must be set in the output and batch")

        # Group the flat score vector by query, then give the targets a matching leading shape.
        output.scores = output.scores.view(len(batch.query_ids), -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)

        losses = []
        # NOTE(review): entries are unpacked as (loss_function, weight) pairs, so the base class presumably
        # normalises bare LossFunction entries into tuples. The weight is not applied here; presumably the
        # caller applies it. Confirm against LightningIRModule.
        for loss_function, _ in self.loss_functions:
            if not isinstance(loss_function, ScoringLossFunction):
                raise RuntimeError(f"Loss function {loss_function} is not a scoring loss function")
            losses.append(loss_function.compute_loss(output, batch))
        return losses