Source code for lightning_ir.base.module

  1"""LightningModule for Lightning IR.
  2
  3This module contains the main module class deriving from a LightningModule_.
  4
  5.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
  6"""
  7
  8from collections import defaultdict
  9from pathlib import Path
 10from typing import Any, Dict, List, Sequence, Tuple, Type
 11
 12import torch
 13from lightning import LightningModule
 14from transformers import BatchEncoding
 15
 16from ..data import RankBatch, SearchBatch, TrainBatch
 17from ..loss.loss import InBatchLossFunction, LossFunction
 18from .config import LightningIRConfig
 19from .model import LightningIRModel, LightningIROutput
 20from .tokenizer import LightningIRTokenizer
 21from .validation_utils import create_qrels_from_dicts, create_run_from_scores, evaluate_run
 22
 23
class LightningIRModule(LightningModule):
    """LightningIRModule base class. It derives from a LightningModule_. LightningIRModules contain a
    LightningIRModel and a LightningIRTokenizer and implement the training, validation, and testing steps for the
    model. Derived classes must implement the forward method for the model.

    .. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: LightningIRConfig | None = None,
        model: LightningIRModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
    ):
        """Initializes the LightningIRModule.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: LightningIRConfig to apply when loading from backbone model, defaults to None
        :type config: LightningIRConfig | None, optional
        :param model: Already instantiated Lightning IR model, defaults to None
        :type model: LightningIRModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :raises ValueError: If both model and model_name_or_path are provided
        :raises ValueError: If neither model nor model_name_or_path are provided
        """
        super().__init__()
        self.save_hyperparameters()
        if model is not None and model_name_or_path is not None:
            raise ValueError("Only one of model or model_name_or_path must be provided.")
        if model is None:
            if model_name_or_path is None:
                raise ValueError("Either model or model_name_or_path must be provided.")
            model = LightningIRModel.from_pretrained(model_name_or_path, config=config)

        self.model: LightningIRModel = model
        self.config = self.model.config
        self.loss_functions: List[Tuple[LossFunction, float]] | None = None
        if loss_functions is not None:
            self.loss_functions = []
            for loss_function in loss_functions:
                if isinstance(loss_function, LossFunction):
                    self.loss_functions.append((loss_function, 1.0))
                else:
                    self.loss_functions.append(loss_function)
        self.evaluation_metrics = evaluation_metrics
        self._optimizer: torch.optim.Optimizer | None = None
        self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
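
    # Example (illustrative sketch, not part of the library source): constructing a module
    # programmatically. Since ``forward`` is abstract, a concrete subclass is instantiated in
    # practice (called ``BiEncoderModule`` here purely for illustration); the constructor
    # arguments are the ones documented above, and the metric strings follow ir-measures.
    #
    #     module = BiEncoderModule(
    #         model_name_or_path="bert-base-uncased",
    #         evaluation_metrics=["nDCG@10", "RR@10"],
    #     )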

    def on_train_start(self) -> None:
        """Called at the beginning of training after sanity check."""
        super().on_train_start()
        # NOTE huggingface models are in eval mode by default
        self.model = self.model.train()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using
        Lightning IR programmatically, the optimizer must be set using :meth:`set_optimizer`.

        :raises ValueError: If optimizer is not set
        :return: Optimizer
        :rtype: torch.optim.Optimizer
        """
        if self._optimizer is None:
            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
        return self._optimizer

    def set_optimizer(
        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
    ) -> "LightningIRModule":
        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.

        :param optimizer: Torch optimizer class
        :type optimizer: Type[torch.optim.Optimizer]
        :param optimizer_kwargs: Arguments to initialize the optimizer
        :type optimizer_kwargs: Dict[str, Any]
        :return: self
        :rtype: LightningIRModule
        """
        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
        return self
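
    # Example (illustrative sketch): configuring the optimizer before programmatic fine-tuning.
    # Any ``torch.optim`` optimizer class works; the learning rate below is a placeholder.
    #
    #     module.set_optimizer(torch.optim.AdamW, lr=1e-5)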

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
        """Computes relevance scores for queries and documents.

        :param queries: Queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: LightningIROutput
        """
        if isinstance(queries, str):
            queries = (queries,)
        if isinstance(docs[0], str):
            docs = (docs,)
        batch = RankBatch(queries, docs, None, None)
        with torch.no_grad():
            return self.forward(batch)
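
    # Example (illustrative sketch): scoring a single query against two documents. The inputs
    # are wrapped in a ``RankBatch`` and passed through ``forward`` without gradients, so the
    # returned ``LightningIROutput`` exposes the relevance scores via ``output.scores``.
    #
    #     output = module.score("what is information retrieval?", ["first doc", "second doc"])
    #     print(output.scores)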

    def forward(self, batch: TrainBatch | RankBatch | SearchBatch) -> LightningIROutput:
        """Handles the forward pass of the model.

        :param batch: Batch of training or ranking data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :raises NotImplementedError: Must be implemented by derived class
        :return: Model output
        :rtype: LightningIROutput
        """
        raise NotImplementedError

    def prepare_input(
        self, queries: Sequence[str] | None, docs: Sequence[str] | None, num_docs: Sequence[int] | int | None
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents and returns the tokenized BatchEncoding_.

        .. _BatchEncoding: https://huggingface.co/transformers/main_classes/tokenizer#transformers.BatchEncoding

        :param queries: Queries to tokenize
        :type queries: Sequence[str] | None
        :param docs: Documents to tokenize
        :type docs: Sequence[str] | None
        :param num_docs: Number of documents per query, if None num_docs is inferred by `len(docs) // len(queries)`,
            defaults to None
        :type num_docs: Sequence[int] | int | None
        :return: Tokenized queries and documents, format depends on the tokenizer
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = self.tokenizer.tokenize(
            queries, docs, return_tensors="pt", padding=True, truncation=True, num_docs=num_docs
        )
        for key in encodings:
            encodings[key] = encodings[key].to(self.device)
        return encodings
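
    # Example (illustrative sketch): tokenizing two queries with two documents each. With
    # ``num_docs=None`` the number of documents per query is inferred as
    # ``len(docs) // len(queries)``, so the flat list below is split evenly across the queries.
    #
    #     encodings = module.prepare_input(
    #         queries=["q1", "q2"],
    #         docs=["q1 doc a", "q1 doc b", "q2 doc a", "q2 doc b"],
    #         num_docs=None,
    #     )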

    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        """Handles the training step for the model.

        :param batch: Batch of training data
        :type batch: TrainBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :raises ValueError: If no loss functions are set
        :return: Sum of the losses weighted by the loss weights
        :rtype: torch.Tensor
        """
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):
            self.log(loss_function.__class__.__name__, loss)
            total_loss = total_loss + loss * loss_weight
        self.log("loss", total_loss, prog_bar=True)
        return total_loss
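
    # Worked example of the weighted sum (illustrative numbers): with
    # loss_functions = [(loss_a, 1.0), (loss_b, 0.5)] and per-loss values [0.8, 0.4],
    # the logged total loss is 1.0 * 0.8 + 0.5 * 0.4 = 1.0.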

    def validation_step(
        self, batch: TrainBatch | RankBatch | SearchBatch, batch_idx: int, dataloader_idx: int = 0
    ) -> LightningIROutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        output = self.forward(batch)

        if self.evaluation_metrics is None:
            return output

        dataset_id = self.get_dataset_id(dataloader_idx)
        metrics = self.validate(output, batch)
        for key, value in metrics.items():
            key = f"{dataset_id}/{key}"
            self.log(key, value, batch_size=len(batch.queries))
        return output

    def test_step(
        self,
        batch: TrainBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> LightningIROutput:
        """Handles the testing step for the model. Passes the batch to the validation step.

        :param batch: Batch of testing data
        :type batch: TrainBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: LightningIROutput
        """
        return self.validation_step(batch, batch_idx, dataloader_idx)

    def get_dataset_id(self, dataloader_idx: int) -> str:
        """Gets the dataset id from the dataloader index for logging.

        .. _ir-datasets: https://ir-datasets.com/

        :param dataloader_idx: Index of the dataloader
        :type dataloader_idx: int
        :return: ir-datasets_ dataset id or dataloader index
        :rtype: str
        """
        dataset_id = str(dataloader_idx)
        datamodule = None
        try:
            datamodule = getattr(self.trainer, "datamodule", None)
            dataset_id = datamodule.inference_datasets[dataloader_idx].dataset_id
        except Exception:
            pass
        return dataset_id

    def validate(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics and loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Dictionary of evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        if self.evaluation_metrics is None or output.scores is None:
            return metrics
        metrics.update(self.validate_metrics(output, batch))
        metrics.update(self.validate_loss(output, batch))
        return metrics

    def validate_metrics(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the evaluation metrics.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        qrels = batch.qrels
        if self.evaluation_metrics is None or qrels is None:
            return metrics
        query_ids = batch.query_ids
        doc_ids = batch.doc_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if doc_ids is None:
            raise ValueError("doc_ids must be set")
        evaluation_metrics = [metric for metric in self.evaluation_metrics if metric != "loss"]
        ir_measures_qrels = create_qrels_from_dicts(qrels)
        if evaluation_metrics and qrels is not None and output.scores is not None:
            run = create_run_from_scores(query_ids, doc_ids, output.scores)
            metrics.update(evaluate_run(run, ir_measures_qrels, evaluation_metrics))
        return metrics

    def validate_loss(
        self,
        output: LightningIROutput,
        batch: TrainBatch | RankBatch | SearchBatch,
    ) -> Dict[str, float]:
        """Validates the model output with the loss functions.

        :param output: Model output
        :type output: LightningIROutput
        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | RankBatch | SearchBatch
        :return: Evaluation metrics
        :rtype: Dict[str, float]
        """
        metrics: Dict[str, float] = {}
        query_ids = batch.query_ids
        if query_ids is None:
            raise ValueError("query_ids must be set")
        if (
            self.evaluation_metrics is None
            or "loss" not in self.evaluation_metrics
            or getattr(batch, "targets", None) is None
            or self.loss_functions is None
            or output.scores is None
        ):
            return metrics
        output.scores = output.scores.view(len(query_ids), -1)
        for loss_function, _ in self.loss_functions:
            # NOTE skip in-batch losses because they can use a lot of memory
            if isinstance(loss_function, InBatchLossFunction):
                continue
            metrics[f"validation-{loss_function.__class__.__name__}"] = loss_function.compute_loss(output, batch).item()
        return metrics

    def on_validation_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader."""
        try:
            trainer = self.trainer
        except RuntimeError:
            trainer = None
        if trainer is not None:
            metrics = trainer.callback_metrics
            accum_metrics = defaultdict(list)
            for key, value in metrics.items():
                split = key.split("/")
                if "dataloader_idx" in split[-1]:
                    accum_metrics[split[-2]].append(value)
            for key, value in accum_metrics.items():
                self.log(key, torch.stack(value).mean(), logger=False)
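
    # Example (illustrative, assuming Lightning appends a per-dataloader suffix to logged keys):
    # a metric logged as "msmarco-passage/trec-dl-2019/judged/nDCG@10/dataloader_idx_0" is split
    # on "/", matched via its "dataloader_idx" suffix, accumulated under "nDCG@10", and the mean
    # over all dataloaders is re-logged under that shorter key.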

    def on_test_epoch_end(self) -> None:
        """Logs the accumulated metrics for each dataloader."""
        self.on_validation_epoch_end()

    def save_pretrained(self, save_path: str | Path) -> None:
        """Saves the model and tokenizer to the save path.

        :param save_path: Path to save the model and tokenizer
        :type save_path: str | Path
        """
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
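
    # Example (illustrative sketch): exporting the fine-tuned model and tokenizer as a
    # Hugging Face-style checkpoint directory. The saved path can later be passed as
    # ``model_name_or_path`` to reload the fine-tuned Lightning IR model.
    #
    #     module.save_pretrained("models/my-finetuned-model")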

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Saves the model and tokenizer to the trainer's log directory."""
        if self.trainer is not None and self.trainer.log_dir is not None:
            if self.trainer.global_rank != 0:
                return
            _step = self.trainer.global_step
            self.config.save_step = _step
            log_dir = Path(self.trainer.log_dir)
            save_path = log_dir / "huggingface_checkpoint"
            self.save_pretrained(save_path)