1"""LightningModule for Lightning IR.
2
3This module contains the main module class deriving from a LightningModule_.
4
5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
6"""
7
8from collections import defaultdict
9from pathlib import Path
10from typing import Any, Dict, List, Sequence, Tuple, Type
11
12import torch
13from lightning import LightningModule
14from transformers import BatchEncoding
15
16from ..data import RankBatch, SearchBatch, TrainBatch
17from ..loss.loss import InBatchLossFunction, LossFunction
18from .config import LightningIRConfig
19from .model import LightningIRModel, LightningIROutput
20from .tokenizer import LightningIRTokenizer
21from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run


class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
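
    Example (a hypothetical sketch of a derived class; ``ConstantScorer`` and its
    constant scores are illustrative only):

    .. code-block:: python

        class ConstantScorer(LightningIRModule):
            def forward(self, batch):
                # score every query-document pair with a constant; a real module
                # would run self.model on the tokenized queries and documents
                num_pairs = sum(len(docs) for docs in batch.docs)
                return LightningIROutput(scores=torch.zeros(num_pairs))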
30 """
31
[docs]
32 def __init__(
33 self,
34 model_name_or_path: str | None = None,
35 config: LightningIRConfig | None = None,
36 model: LightningIRModel | None = None,
37 loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
38 evaluation_metrics: Sequence[str] | None = None,
39 ):
40 """Initializes the LightningIRModule.
41
42 .. _ir-measures: https://ir-measur.es/en/latest/index.html
43
44 :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
45 :type model_name_or_path: str | None, optional
46 :param config: LightningIRConfig to apply when loading from backbone model, defaults to None
47 :type config: LightningIRConfig | None, optional
48 :param model: Already instantiated Lightning IR model, defaults to None
49 :type model: LightningIRModel | None, optional
50 :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
51 loss function, defaults to None
52 :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
53 :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
54 testing, defaults to None
55 :type evaluation_metrics: Sequence[str] | None, optional
56 :raises ValueError: If both model and model_name_or_path are provided
57 :raises ValueError: If neither model nor model_name_or_path are provided
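
        Example (a hypothetical usage sketch; ``MyDerivedModule`` stands in for a concrete
        subclass, and the model name and metric are illustrative):

        .. code-block:: python

            module = MyDerivedModule(
                model_name_or_path="bert-base-uncased",
                evaluation_metrics=["nDCG@10"],
            )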
58 """
59 super().__init__()
60 self.save_hyperparameters()
61 if model is not None and model_name_or_path is not None:
62 raise ValueError("Only one of model or model_name_or_path must be provided.")
63 if model is None:
64 if model_name_or_path is None:
65 raise ValueError("Either model or model_name_or_path must be provided.")
66 model = LightningIRModel.from_pretrained(model_name_or_path, config=config)
67
68 self.model: LightningIRModel = model
69 self.config = self.model.config
70 self.loss_functions: List[Tuple[LossFunction, float]] | None = None
71 if loss_functions is not None:
72 self.loss_functions = []
73 for loss_function in loss_functions:
74 if isinstance(loss_function, LossFunction):
75 self.loss_functions.append((loss_function, 1.0))
76 else:
77 self.loss_functions.append(loss_function)
78 self.evaluation_metrics = evaluation_metrics
79 self._optimizer: torch.optim.Optimizer | None = None
80 self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)

    def on_train_start(self) -> None:
        """Called at the beginning of training after the sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()
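
    # NOTE a minimal sketch, assuming the standard LightningModule hook simply
    # returns the optimizer registered via ``set_optimizer`` below
    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Returns the optimizer set via :meth:`set_optimizer`.

        :raises ValueError: If no optimizer has been set
        :return: Optimizer
        :rtype: torch.optim.Optimizer
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer` or use the CLI.")
        return self._optimizer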

    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        :param optimizer: Torch optimizer class
        :type optimizer: Type[torch.optim.Optimizer]
        :param optimizer_kwargs: Arguments to initialize the optimizer
        :type optimizer_kwargs: Dict[str, Any]
        :return: self
        :rtype: LightningIRModule
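
        Example (a hypothetical usage sketch; the optimizer class and learning rate are
        illustrative):

        .. code-block:: python

            module = module.set_optimizer(torch.optim.AdamW, lr=1e-5)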
111 """
112 self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
113 return self
114

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        :param queries: Query or queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score, one sequence of documents per query
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: LightningIROutput
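
        Example (a hypothetical usage sketch; the query and documents are illustrative):

        .. code-block:: python

            output = module.score("query", ["relevant document", "irrelevant document"])
            print(output.scores)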
124 """
125 if isinstance(queries, str):
126 queries = (queries,)
127 if isinstance(docs[0], str):
128 docs = (docs,)
129 batch = RankBatch(queries, docs, None, None)
130 with torch.no_grad():
131 return self.forward(batch)
132

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        :param batch: Batch of training, ranking, or search data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: Model output
        :rtype: LightningIROutput
        """
        raise NotImplementedError
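
    # NOTE a hypothetical sketch of a tokenization helper (the otherwise unused
    # ``BatchEncoding`` import above suggests one belongs here); the ``tokenize``
    # signature of the LightningIRTokenizer is an assumption
    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and moves the encodings to the model's device.

        :return: Tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings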

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :raises ValueError: If no loss functions are set
        :return: Sum of the losses weighted by the loss weights
        :rtype: torch.Tensor
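
        Example (a hypothetical configuration sketch; assumes ``RankNet`` and
        ``KLDivergence`` loss functions from the loss module, the second one
        down-weighted to 0.5):

        .. code-block:: python

            module = MyDerivedModule(
                model_name_or_path="bert-base-uncased",
                loss_functions=[RankNet(), (KLDivergence(), 0.5)],
            )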
182 """
183 if self.loss_functions is None:
184 raise ValueError("Loss functions are not set")
185 output = self.forward(batch)
186 losses = self._compute_losses(batch, output)
187 total_loss = torch.tensor(0)
188 assert len(losses) == len(self.loss_functions)
189 for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
190 self.log(loss_function.__class__.__name__, loss)
191 total_loss = total_loss + loss * loss_weight
192 self.log("loss", total_loss, prog_bar=True)
193 return total_loss
194

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset_id = self.get_dataset_id(dataloader_idx)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            # prefix each metric with the dataset id for per-dataset logging
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        :param batch: Batch of testing data
        :type batch: TrainBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset_id(self, dataloader_idx: int) -> str:
        """Gets the dataset id from the dataloader index for logging.

        .. _ir-datasets: https://ir-datasets.com/

        :param dataloader_idx: Index of the dataloader
        :type dataloader_idx: int
        :return: ir-datasets_ dataset id or dataloader index
        :rtype: str
        """
        dataset_id = str(dataloader_idx)
        try:
            datamodule = getattr(self.trainer, "datamodule", None)
            dataset_id = datamodule.inference_datasets[dataloader_idx].dataset_id
        except Exception:
            # fall back to the dataloader index if no datamodule or dataset id is available
            pass
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Dictionary of evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(
                output, batch
            ).item()
        return metrics

    def on_validation_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader."""
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is not None:
            metrics = trainer.callback_metrics
            accum_metrics = defaultdict(list)
            for key, value in metrics.items():
                split = key.split("/")
                if "dataloader_idx" in split[-1]:
                    accum_metrics[split[-2]].append(value)
            for key, value in accum_metrics.items():
                self.log(key, torch.stack(value).mean(), logger=False)

    def on_test_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader."""
        self.on_validation_epoch_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        :param save_path: Path to save the model and tokenizer
        :type save_path: str | Path
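
        Example (a usage sketch; the path is illustrative):

        .. code-block:: python

            module.save_pretrained("path/to/checkpoint")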
370 """
371 self.model.save_pretrained(save_path)
372 self.tokenizer.save_pretrained(save_path)
373

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            # only the rank-zero process writes the huggingface checkpoint
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)