SearchCallback
- class lightning_ir.callbacks.callbacks.SearchCallback(search_config: SearchConfig, index_dir: Path | str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True)[source]
Bases: RankCallback, _IndexDirMixin
- __init__(search_config: SearchConfig, index_dir: Path | str | None = None, save_dir: Path | str | None = None, run_name: str | None = None, overwrite: bool = False, use_gpu: bool = True) None [source]
Callback that uses an index to retrieve documents efficiently.
- Parameters:
search_config (SearchConfig) – Configuration of the Searcher
index_dir (Path | str | None, optional) – Directory where indexes are stored, defaults to None
save_dir (Path | str | None, optional) – Directory to save run files to. If None, run files are saved in the model’s directory, defaults to None
run_name (str | None, optional) – Name of the run file. If None, the dataset’s dataset_id or file name will be used, defaults to None
overwrite (bool, optional) – Whether to overwrite existing run files (if False, datasets with existing run files are skipped), defaults to False
use_gpu (bool, optional) – Whether to use the GPU for retrieval, defaults to True
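When the callback is configured through a Lightning CLI YAML file, it can be attached to the trainer roughly as in the following sketch. This is an illustrative fragment only: the searcher config class (FaissSearchConfig), its k argument, and all paths are assumptions, not values taken from this page.

```yaml
trainer:
  callbacks:
    - class_path: lightning_ir.callbacks.callbacks.SearchCallback
      init_args:
        search_config:
          class_path: lightning_ir.FaissSearchConfig  # hypothetical searcher config
          init_args:
            k: 100                                    # illustrative retrieval depth
        index_dir: ./indexes/my-index                 # illustrative path
        save_dir: ./runs                              # run files written here
        overwrite: false
        use_gpu: true
```

With this in place, running the test stage (e.g. `trainer.test(...)` or the corresponding CLI subcommand) triggers retrieval and writes one run file per test dataset into save_dir.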
Methods
- __init__(search_config[, index_dir, ...]): Callback that uses an index to retrieve documents efficiently.
- load_state_dict(state_dict): Called when loading a checkpoint; implement to reload callback state given the callback's state_dict.
- on_after_backward(trainer, pl_module): Called after loss.backward() and before optimizers are stepped.
- on_before_backward(trainer, pl_module, loss): Called before loss.backward().
- on_before_optimizer_step(trainer, pl_module, ...): Called before optimizer.step().
- on_before_zero_grad(trainer, pl_module, ...): Called before optimizer.zero_grad().
- on_exception(trainer, pl_module, exception): Called when any trainer execution is interrupted by an exception.
- on_fit_end(trainer, pl_module): Called when fit ends.
- on_fit_start(trainer, pl_module): Called when fit begins.
- on_load_checkpoint(trainer, pl_module, ...): Called when loading a model checkpoint, use to reload state.
- on_predict_batch_end(trainer, pl_module, ...): Called when the predict batch ends.
- on_predict_batch_start(trainer, pl_module, ...): Called when the predict batch begins.
- on_predict_end(trainer, pl_module): Called when predict ends.
- on_predict_epoch_end(trainer, pl_module): Called when the predict epoch ends.
- on_predict_epoch_start(trainer, pl_module): Called when the predict epoch begins.
- on_predict_start(trainer, pl_module): Called when the predict begins.
- on_sanity_check_end(trainer, pl_module): Called when the validation sanity check ends.
- on_sanity_check_start(trainer, pl_module): Called when the validation sanity check starts.
- on_save_checkpoint(trainer, pl_module, ...): Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- on_test_batch_end(trainer, pl_module, ...[, ...]): Hook to aggregate and write the ranking to file.
- on_test_batch_start(trainer, pl_module, ...): Hook to initialize the searcher for new datasets.
- on_test_end(trainer, pl_module): Called when the test ends.
- on_test_epoch_end(trainer, pl_module): Called when the test epoch ends.
- on_test_epoch_start(trainer, pl_module): Called when the test epoch begins.
- on_test_start(trainer, pl_module): Hook to validate datasets.
- on_train_batch_end(trainer, pl_module, ...): Called when the train batch ends.
- on_train_batch_start(trainer, pl_module, ...): Called when the train batch begins.
- on_train_end(trainer, pl_module): Called when the train ends.
- on_train_epoch_end(trainer, pl_module): Called when the train epoch ends.
- on_train_epoch_start(trainer, pl_module): Called when the train epoch begins.
- on_train_start(trainer, pl_module): Called when the train begins.
- on_validation_batch_end(trainer, pl_module, ...): Called when the validation batch ends.
- on_validation_batch_start(trainer, ...[, ...]): Called when the validation batch begins.
- on_validation_end(trainer, pl_module): Called when the validation loop ends.
- on_validation_epoch_end(trainer, pl_module): Called when the val epoch ends.
- on_validation_epoch_start(trainer, pl_module): Called when the val epoch begins.
- on_validation_start(trainer, pl_module): Called when the validation loop begins.
- setup(trainer, pl_module, stage): Hook to setup the callback.
- state_dict(): Called when saving a checkpoint; implement to generate the callback's state_dict.
- teardown(trainer, pl_module, stage): Called when fit, validate, test, predict, or tune ends.
Attributes
- state_key: Identifier for the state of the callback.
- index_dir
- load_state_dict(state_dict: dict[str, Any]) None
Called when loading a checkpoint; implement to reload callback state given the callback's state_dict.
- Parameters:
state_dict – the callback state returned by state_dict.
- on_after_backward(trainer: Trainer, pl_module: LightningModule) None
Called after loss.backward() and before optimizers are stepped.
- on_before_backward(trainer: Trainer, pl_module: LightningModule, loss: Tensor) None
Called before loss.backward().
- on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None
Called before optimizer.step().
- on_before_zero_grad(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None
Called before optimizer.zero_grad().
- on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None
Called when any trainer execution is interrupted by an exception.
- on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called when the predict batch ends.
- on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called when the predict batch begins.
- on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None
Called when the predict epoch ends.
- on_predict_epoch_start(trainer: Trainer, pl_module: LightningModule) None
Called when the predict epoch begins.
- on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) None
Called when the validation sanity check ends.
- on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) None
Called when the validation sanity check starts.
- on_save_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
- on_test_batch_end(trainer: Trainer, pl_module: LightningIRModule, outputs: LightningIROutput, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Hook to aggregate and write ranking to file.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (LightningIRModule) – LightningIR Module
outputs (LightningIROutput) – Scored query-document pairs
batch (Any) – Batch of input data
batch_idx (int) – Index of batch in the current dataset
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- on_test_batch_start(trainer: Trainer, pl_module: BiEncoderModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None [source]
Hook to initialize searcher for new datasets.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (BiEncoderModule) – LightningIR BiEncoderModule
batch (Any) – Batch of input data
batch_idx (int) – Index of batch in dataset
dataloader_idx (int, optional) – Index of the dataloader, defaults to 0
- on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None
Called when the test epoch ends.
- on_test_epoch_start(trainer: Trainer, pl_module: LightningModule) None
Called when the test epoch begins.
- on_test_start(trainer: Trainer, pl_module: BiEncoderModule) None [source]
Hook to validate datasets.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (BiEncoderModule) – LightningIR BiEncoderModule
- Raises:
ValueError – If no test_dataloaders are found
ValueError – If not all datasets are QueryDataset
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) None
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None
Called when the train batch begins.
- on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
- on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) None
Called when the train epoch begins.
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called when the validation batch ends.
- on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None
Called when the validation batch begins.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) None
Called when the validation loop ends.
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None
Called when the val epoch ends.
- on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) None
Called when the val epoch begins.
- on_validation_start(trainer: Trainer, pl_module: LightningModule) None
Called when the validation loop begins.
- setup(trainer: Trainer, pl_module: LightningIRModule, stage: str) None
Hook to setup the callback.
- Parameters:
trainer (Trainer) – PyTorch Lightning Trainer
pl_module (LightningIRModule) – LightningIR module
stage (str) – Stage of the trainer, must be “test”
- Raises:
ValueError – If the stage is not “test”
ValueError – If no save_dir is provided and model_name_or_path is not a path (the model is not local)
- state_dict() dict[str, Any]
Called when saving a checkpoint; implement to generate the callback's state_dict.
- Returns:
A dictionary containing callback state.
- property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
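Why a unique state key matters can be illustrated with a small self-contained sketch. Note these are plain-Python stand-ins for Lightning's Callback machinery, not the actual classes: when two instances of the same callback class share a state key, one would overwrite the other's entry in checkpoint["callbacks"], so the key must encode the distinguishing constructor arguments.

```python
# Minimal stand-in for a Lightning-style callback (assumption: simplified,
# not the real lightning.pytorch.Callback class).
class Callback:
    @property
    def state_key(self) -> str:
        # Default: just the class name -- identical for all instances.
        return type(self).__name__

    def state_dict(self) -> dict:
        return {}


class CounterCallback(Callback):
    def __init__(self, name: str):
        self.name = name
        self.count = 0

    @property
    def state_key(self) -> str:
        # Include the constructor arguments so two instances do not
        # collide under the same key in checkpoint["callbacks"].
        return f"{type(self).__name__}{{'name': {self.name!r}}}"

    def state_dict(self) -> dict:
        return {"count": self.count}


a, b = CounterCallback("train"), CounterCallback("val")
a.count, b.count = 3, 7
# Mimic how a checkpoint stores callback state, keyed by state_key.
checkpoint = {"callbacks": {cb.state_key: cb.state_dict() for cb in (a, b)}}
print(checkpoint["callbacks"][a.state_key])  # {'count': 3}
```

With the default class-name key, the dictionary comprehension above would keep only one of the two entries; the argument-qualified key preserves both.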