ReRankCallback

class lightning_ir.callbacks.callbacks.ReRankCallback(save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False)[source]

Bases: RankCallback

__init__(save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False) → None

Callback to write a run file of re-ranked documents to disk.

Parameters:
  • save_dir (Path | str | None, optional) – Directory to save run files to. If None, run files will be saved in the model’s directory, defaults to None

  • run_name (str | None, optional) – Name of the run file. If None, the dataset’s dataset_id or file name will be used, defaults to None

  • overwrite (bool, optional) – Whether to overwrite existing run files; if False, datasets with an existing run file are skipped, defaults to False

Methods

__init__([save_dir, run_name, overwrite])

Callback to write a run file of re-ranked documents to disk.

load_state_dict(state_dict)

Called when loading a checkpoint, implement to reload callback state given callback's state_dict.

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

on_before_optimizer_step(trainer, pl_module, ...)

Called before optimizer.step().

on_before_zero_grad(trainer, pl_module, ...)

Called before optimizer.zero_grad().

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer, pl_module)

Called when fit ends.

on_fit_start(trainer, pl_module)

Called when fit begins.

on_load_checkpoint(trainer, pl_module, ...)

Called when loading a model checkpoint, use to reload state.

on_predict_batch_end(trainer, pl_module, ...)

Called when the predict batch ends.

on_predict_batch_start(trainer, pl_module, ...)

Called when the predict batch begins.

on_predict_end(trainer, pl_module)

Called when predict ends.

on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

on_predict_start(trainer, pl_module)

Called when the predict begins.

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

on_save_checkpoint(trainer, pl_module, ...)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end(trainer, pl_module, ...[, ...])

Hook to aggregate and write ranking to file.

on_test_batch_start(trainer, pl_module, ...)

Called when the test batch begins.

on_test_end(trainer, pl_module)

Called when the test ends.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_test_start(trainer, pl_module)

Called when the test begins.

on_train_batch_end(trainer, pl_module, ...)

Called when the train batch ends.

on_train_batch_start(trainer, pl_module, ...)

Called when the train batch begins.

on_train_end(trainer, pl_module)

Called when the train ends.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

on_train_start(trainer, pl_module)

Called when the train begins.

on_validation_batch_end(trainer, pl_module, ...)

Called when the validation batch ends.

on_validation_batch_start(trainer, ...[, ...])

Called when the validation batch begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_epoch_end(trainer, pl_module)

Called when the val epoch ends.

on_validation_epoch_start(trainer, pl_module)

Called when the val epoch begins.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

setup(trainer, pl_module, stage)

Hook to set up the callback.

state_dict()

Called when saving a checkpoint, implement to generate callback's state_dict.

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Attributes

state_key

Identifier for the state of the callback.

load_state_dict(state_dict: dict[str, Any]) → None

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict – the callback state returned by state_dict.

on_after_backward(trainer: Trainer, pl_module: LightningModule) → None

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer: Trainer, pl_module: LightningModule, loss: Tensor) → None

Called before loss.backward().

on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) → None

Called before optimizer.step().

on_before_zero_grad(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) → None

Called before optimizer.zero_grad().

on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) → None

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer: Trainer, pl_module: LightningModule) → None

Called when fit ends.

on_fit_start(trainer: Trainer, pl_module: LightningModule) → None

Called when fit begins.

on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) → None

Called when loading a model checkpoint, use to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Called when the predict batch ends.

on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Called when the predict batch begins.

on_predict_end(trainer: Trainer, pl_module: LightningModule) → None

Called when predict ends.

on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the predict epoch ends.

on_predict_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the predict epoch begins.

on_predict_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the predict begins.

on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation sanity check ends.

on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation sanity check starts.

on_save_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) → None

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

on_test_batch_end(trainer: Trainer, pl_module: LightningIRModule, outputs: LightningIROutput, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Hook to aggregate and write ranking to file.

Parameters:
  • trainer (Trainer) – PyTorch Lightning Trainer

  • pl_module (LightningIRModule) – LightningIR Module

  • outputs (LightningIROutput) – Scored query documents pairs

  • batch (Any) – Batch of input data

  • batch_idx (int) – Index of batch in the current dataset

  • dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
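To make "write ranking to file" concrete, here is a stdlib-only sketch of emitting scored query-document pairs in the standard TREC run-file format (query_id Q0 doc_id rank score run_name). The write_run_file helper and the toy scores are invented for illustration and are not lightning_ir's actual implementation:

```python
from collections import defaultdict
from io import StringIO


def write_run_file(out, scores, run_name: str) -> None:
    # Group (query_id, doc_id, score) triples by query, sort each group
    # by descending score, and emit one TREC run-file line per document.
    by_query = defaultdict(list)
    for query_id, doc_id, score in scores:
        by_query[query_id].append((doc_id, score))
    for query_id, docs in by_query.items():
        docs.sort(key=lambda pair: pair[1], reverse=True)
        for rank, (doc_id, score) in enumerate(docs, start=1):
            out.write(f"{query_id} Q0 {doc_id} {rank} {score} {run_name}\n")


buffer = StringIO()
write_run_file(buffer, [("q1", "d7", 0.2), ("q1", "d3", 0.9)], "rerank")
print(buffer.getvalue())
# q1 Q0 d3 1 0.9 rerank
# q1 Q0 d7 2 0.2 rerank
```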

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Called when the test batch begins.

on_test_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the test ends.

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the test epoch ends.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the test epoch begins.

on_test_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the test begins.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) → None

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized w.r.t. accumulate_grad_batches.
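A quick arithmetic illustration of this note (values invented): with accumulate_grad_batches = 4 and training_step returning a loss of 2.0, the value surfaced in this hook is 2.0 / 4:

```python
# Lightning divides the reported loss by the accumulation factor so
# that accumulated batches contribute a comparable per-step value.
accumulate_grad_batches = 4
raw_loss = 2.0
normalized_loss = raw_loss / accumulate_grad_batches
print(normalized_loss)  # → 0.5
```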

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) → None

Called when the train batch begins.

on_train_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the train ends.

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the train epoch begins.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the train begins.

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Called when the validation batch ends.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) → None

Called when the validation batch begins.

on_validation_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation loop ends.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the val epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the val epoch begins.

on_validation_start(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation loop begins.

setup(trainer: Trainer, pl_module: LightningIRModule, stage: str) → None

Hook to set up the callback.

Parameters:
  • trainer (Trainer) – PyTorch Lightning Trainer

  • pl_module (LightningIRModule) – LightningIR module

  • stage (str) – Stage of the trainer, must be “test”

Raises:
  • ValueError – If the stage is not “test”

  • ValueError – If no save_dir is provided and model_name_or_path is not a path (the model is not local)

state_dict() → dict[str, Any]

Called when saving a checkpoint, implement to generate callback’s state_dict.

Returns:

A dictionary containing callback state.
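state_dict and load_state_dict together form a plain save-and-restore protocol; here is a minimal, library-free sketch of a callback that round-trips its state through the two hooks (CounterCallback is a made-up example, not part of lightning_ir):

```python
from typing import Any


class CounterCallback:
    # state_dict returns everything needed to restore this callback;
    # load_state_dict puts a previously saved state back in place.
    def __init__(self) -> None:
        self.batches_seen = 0

    def state_dict(self) -> dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


saver = CounterCallback()
saver.batches_seen = 42
restored = CounterCallback()
restored.load_state_dict(saver.state_dict())
print(restored.batches_seen)  # → 42
```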

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
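For instance, two differently configured instances of the same callback class need distinct state keys so their checkpoint entries do not overwrite each other; one common pattern (a sketch, not Lightning's actual implementation) folds the constructor arguments into the key:

```python
class SaveEveryN:
    # Encoding the configuration in state_key keeps the entries under
    # checkpoint["callbacks"][state_key] separate for each instance.
    def __init__(self, every_n: int) -> None:
        self.every_n = every_n

    @property
    def state_key(self) -> str:
        return f"{type(self).__name__}{{'every_n': {self.every_n}}}"


print(SaveEveryN(5).state_key)   # → SaveEveryN{'every_n': 5}
print(SaveEveryN(10).state_key)  # → SaveEveryN{'every_n': 10}
```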

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Called when fit, validate, test, predict, or tune ends.