Trainer
The LightningIRTrainer
derives from the PyTorch Lightning Trainer to enable easy, scalable, and reproducible fine-tuning. It also adds functionality for further information retrieval stages: indexing, searching, and re-ranking. The trainer combines a LightningIRModule
and a LightningIRDataModule
and handles the fine-tuning and inference logic.
The following sections provide an overview of the trainer’s functionality for the different stages and how to use it.
Note
Lightning IR provides an easy-to-use CLI (based on the PyTorch Lightning CLI) that wraps the trainer and provides commands for all the retrieval stages exemplified below. See the Quickstart Guide for usage examples.
Fine-Tuning
The LightningIRTrainer
is designed for scalable and efficient fine-tuning of models. See the documentation of the parent PyTorch Lightning Trainer for details on hyperparameters, parallelization, logging, checkpointing, and more.
The fit
method of the LightningIRTrainer
starts the fine-tuning process. A LightningIRModule
, a LightningIRDataModule
, an Optimizer, and at least one LossFunction
must first be configured. Several popular loss functions for fine-tuning ranking models are available in the loss
module. The snippet below demonstrates how to fine-tune a BiEncoderModel
or a CrossEncoderModel
on the official MS MARCO triples.
from torch.optim import AdamW

from lightning_ir import (
    BiEncoderConfig,
    BiEncoderModule,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",  # backbone model
    config=BiEncoderConfig(),
    loss_functions=[RankNet()],  # or other loss functions
)
# or, after also importing CrossEncoderConfig and CrossEncoderModule:
# module = CrossEncoderModule(
#     model_name_or_path="bert-base-uncased",  # backbone model
#     config=CrossEncoderConfig(),
#     loss_functions=[RankNet()],  # or other loss functions
# )
module.set_optimizer(AdamW, lr=1e-5)

# Define the data module
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)

# Define the trainer
trainer = LightningIRTrainer(max_steps=100_000)

# Fine-tune the model
trainer.fit(module, data_module)
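To give a sense of what a pairwise loss such as RankNet optimizes, the sketch below computes the standard RankNet (pairwise logistic) loss for a single positive/negative document pair in plain Python. It is an illustrative re-implementation of the textbook formula, not Lightning IR's own code; the function name ranknet_loss and the example scores are hypothetical.

```python
import math


def ranknet_loss(pos_score: float, neg_score: float) -> float:
    """Pairwise logistic (RankNet) loss: log(1 + exp(-(s_pos - s_neg)))."""
    margin = pos_score - neg_score
    # Numerically stable evaluation of softplus(-margin)
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Hypothetical relevance scores produced by a ranking model
well_ordered = ranknet_loss(pos_score=4.0, neg_score=1.0)  # positive scored higher
tied = ranknet_loss(pos_score=2.0, neg_score=2.0)  # model cannot tell them apart
mis_ordered = ranknet_loss(pos_score=1.0, neg_score=4.0)  # negative scored higher
```

The loss shrinks as the relevant document is scored further above the non-relevant one and grows roughly linearly when the order is reversed, which is what pushes the model toward correct pairwise orderings.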
Indexing
To index a document collection using an already fine-tuned lightning_ir.bi_encoder.model.BiEncoderModel
, use the index
method of the LightningIRTrainer
. The trainer must receive an IndexCallback
, which handles writing the document embeddings to disk. The IndexCallback
is configured with an IndexConfig
that specifies the type of index to use and how this index should be configured. If the selected lightning_ir.bi_encoder.model.BiEncoderModel
generates sparse embeddings, a SparseIndexConfig
should be used. For dense embeddings, Lightning IR provides a lightning_ir.retrieve.faiss_indexer.FaissIndexConfig
that uses faiss for fast approximate nearest neighbor search. The snippet below demonstrates how to index the MS MARCO passage ranking dataset using an already fine-tuned bi-encoder.
from lightning_ir import (
    BiEncoderModule,
    DocDataset,
    FaissFlatIndexConfig,
    IndexCallback,
    LightningIRDataModule,
    LightningIRTrainer,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="webis/bert-bi-encoder",
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)

# Define the index callback
callback = IndexCallback(
    index_dir="./msmarco-passage-index",
    index_config=FaissFlatIndexConfig(),
)

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Index the data
trainer.index(module, data_module)
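Conceptually, a flat index stores every document embedding and scores a query against all of them. The sketch below illustrates that idea with an exhaustive dot-product search in plain Python. It does not use faiss or Lightning IR; the function names, the toy three-dimensional embeddings, and the choice of dot-product similarity are assumptions made for illustration only.

```python
import heapq


def dot(a: list[float], b: list[float]) -> float:
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def search_flat(
    index: dict[str, list[float]], query: list[float], k: int
) -> list[tuple[str, float]]:
    """Score the query against every stored embedding and keep the top k."""
    scored = ((doc_id, dot(query, emb)) for doc_id, emb in index.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Hypothetical document embeddings keyed by document id
index = {
    "d1": [1.0, 0.0, 0.0],
    "d2": [0.0, 1.0, 0.0],
    "d3": [0.7, 0.7, 0.0],
    "d4": [0.0, 0.0, 1.0],
}

# Retrieve the two highest-scoring documents for a toy query embedding
top = search_flat(index, query=[1.0, 0.2, 0.0], k=2)
```

Exhaustive search like this is exact but scales linearly with the collection size, which is the trade-off that other index types aim to improve on.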
Searching
To search for relevant documents given a query using an already fine-tuned lightning_ir.bi_encoder.model.BiEncoderModel
, use the search
method of the LightningIRTrainer
. The trainer must receive a SearchCallback
, which handles loading the index and searching for relevant documents based on the generated query embeddings. The SearchCallback
is configured with a SearchConfig
that must match the index configuration used during indexing. To save the results to disk in the form of a run file, you can optionally add a lightning_ir.lightning_utils.callbacks.RankCallback
and specify a directory to save run files to. The snippet below demonstrates how to search for relevant documents given a query using an already fine-tuned bi-encoder.
from lightning_ir import (
    BiEncoderModule,
    FaissSearchConfig,
    LightningIRDataModule,
    LightningIRTrainer,
    QueryDataset,
    SearchCallback,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="webis/bert-bi-encoder",
    evaluation_metrics=["nDCG@10"],
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[
        QueryDataset("msmarco-passage/trec-dl-2019/judged"),
        QueryDataset("msmarco-passage/trec-dl-2020/judged"),
    ],
    inference_batch_size=4,
)

# Define the search callback
callback = SearchCallback(
    index_dir="./msmarco-passage-index",
    search_config=FaissSearchConfig(k=100),
    save_dir="./runs",
)

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Retrieve relevant documents
trainer.search(module, data_module)
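For readers unfamiliar with run files, the sketch below writes and parses ranked results in the widely used six-column TREC run format (query id, the literal Q0, document id, rank, score, run name). It assumes the saved run files follow this convention; the helper names, the score precision, and the toy results are hypothetical and not part of Lightning IR.

```python
def format_run(
    results: dict[str, list[tuple[str, float]]], run_name: str
) -> list[str]:
    """Turn per-query (doc_id, score) pairs into TREC-style run lines."""
    lines = []
    for query_id, scored_docs in results.items():
        # Rank documents by descending score, starting at rank 1
        ranked = sorted(scored_docs, key=lambda pair: pair[1], reverse=True)
        for rank, (doc_id, score) in enumerate(ranked, start=1):
            lines.append(f"{query_id} Q0 {doc_id} {rank} {score:.4f} {run_name}")
    return lines


def parse_run_line(line: str) -> tuple[str, str, int, float]:
    """Read one run line back into (query_id, doc_id, rank, score)."""
    query_id, _, doc_id, rank, score, _ = line.split()
    return query_id, doc_id, int(rank), float(score)


# Hypothetical search results for two queries
results = {"q1": [("d7", 11.2), ("d3", 13.5)], "q2": [("d1", 9.0)]}
run_lines = format_run(results, run_name="sketch")
first = parse_run_line(run_lines[0])
```

Because each line carries the query id, document id, rank, and score, a run file written in one stage can serve as the candidate list for a later stage.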
Re-Ranking
To re-rank a set of retrieved documents using an already fine-tuned BiEncoderModel
or CrossEncoderModel
(the latter are usually more effective), use the re_rank
method of the LightningIRTrainer
. The trainer must receive a ReRankCallback
, which handles saving the re-ranked run file to disk. The snippet below demonstrates how to re-rank a set of retrieved documents using an already fine-tuned cross-encoder.
from lightning_ir import (
    CrossEncoderModule,
    LightningIRDataModule,
    LightningIRTrainer,
    ReRankCallback,
    RunDataset,
)

# Define the model
module = CrossEncoderModule(
    model_name_or_path="webis/monoelectra-base",
    evaluation_metrics=["nDCG@10"],
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[
        RunDataset("./runs/msmarco-passage-trec-dl-2019-judged.run"),
        RunDataset("./runs/msmarco-passage-trec-dl-2020-judged.run"),
    ],
    inference_batch_size=4,
)

# Define the re-rank callback
callback = ReRankCallback(save_dir="./re-ranked-runs")

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Re-rank the retrieved documents
trainer.re_rank(module, data_module)
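The following sketch shows, in plain Python, what re-ranking does to a candidate list and how a metric such as nDCG@10 reflects the change. It is a simplified illustration, not Lightning IR's implementation: the function names, the toy scores, and the relevance judgments are hypothetical, and the nDCG variant used here applies linear gains, which may differ from the gain function used by a given evaluation toolkit.

```python
import math


def dcg(gains: list[int], k: int) -> float:
    """Discounted cumulative gain over the top-k positions (linear gain)."""
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains[:k]))


def ndcg(ranking: list[str], qrels: dict[str, int], k: int) -> float:
    """DCG of the ranking divided by the DCG of an ideal ordering."""
    gains = [qrels.get(doc_id, 0) for doc_id in ranking]
    ideal_dcg = dcg(sorted(qrels.values(), reverse=True), k)
    return dcg(gains, k) / ideal_dcg if ideal_dcg > 0 else 0.0


def re_rank(candidates: list[str], scores: dict[str, float]) -> list[str]:
    """Reorder first-stage candidates by the re-ranker's scores."""
    return sorted(candidates, key=lambda doc_id: scores[doc_id], reverse=True)


# Hypothetical first-stage ranking, judgments, and re-ranker scores
first_stage = ["d1", "d2", "d3"]
qrels = {"d3": 1}  # only d3 is judged relevant
re_ranker_scores = {"d1": 0.1, "d2": 0.3, "d3": 0.9}

re_ranked = re_rank(first_stage, re_ranker_scores)
before = ndcg(first_stage, qrels, k=10)
after = ndcg(re_ranked, qrels, k=10)
```

In this toy case the re-ranker moves the only relevant document from the third position to the first, so nDCG@10 rises from 0.5 to 1.0. Re-ranking can only reorder the candidates it is given, so its effectiveness is bounded by what the first-stage run retrieved.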