Trainer

The LightningIRTrainer derives from the PyTorch Lightning Trainer to enable easy, scalable, and reproducible fine-tuning. It also adds functionality for three further information retrieval stages: indexing, searching, and re-ranking. The trainer combines a LightningIRModule and a LightningIRDataModule and handles the fine-tuning and inference logic.

The following sections provide an overview of the trainer’s functionality for the different stages and how to use it.

Note

Lightning IR provides an easy-to-use CLI (based on the PyTorch Lightning CLI) that wraps the trainer and provides commands for all the retrieval stages exemplified below. See the Quickstart Guide for usage examples.

Fine-Tuning

The LightningIRTrainer is designed for scalable and efficient fine-tuning of models. See the documentation of the parent PyTorch Lightning Trainer for details on hyperparameters, parallelization, logging, checkpointing, and more.

The fit method of the LightningIRTrainer starts the fine-tuning process. A LightningIRModule, a LightningIRDataModule, an Optimizer, and at least one LossFunction must first be configured. Several popular loss functions for fine-tuning ranking models are available in the loss module. The snippet below demonstrates how to fine-tune a BiEncoderModel or a CrossEncoderModel on the official MS MARCO triples.

from torch.optim import AdamW

from lightning_ir import (
    BiEncoderConfig,
    BiEncoderModule,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",  # backbone model
    config=BiEncoderConfig(),
    loss_functions=[RankNet()],  # or other loss functions
)
# or, for a cross-encoder (also import CrossEncoderConfig and CrossEncoderModule):
# module = CrossEncoderModule(
#     model_name_or_path="bert-base-uncased",  # backbone model
#     config=CrossEncoderConfig(),
#     loss_functions=[RankNet()],  # or other loss functions
# )
module.set_optimizer(AdamW, lr=1e-5)

# Define the data module
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)

# Define the trainer
trainer = LightningIRTrainer(max_steps=100_000)

# Fine-tune the model
trainer.fit(module, data_module)
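
To make the role of a pairwise loss function such as RankNet more concrete, the sketch below computes the standard RankNet objective in plain Python: for every pair of documents in which one is labeled more relevant than the other, it penalizes the model with log(1 + exp(-(s_i - s_j))). This is only an illustration of the idea, not Lightning IR's implementation; the function name ranknet_loss and its list-based inputs are assumptions made for this example.

```python
import math


def ranknet_loss(scores, targets):
    """Mean pairwise logistic loss over all pairs with targets[i] > targets[j].

    Hypothetical helper for illustration; not part of Lightning IR.
    """
    losses = []
    for s_i, t_i in zip(scores, targets):
        for s_j, t_j in zip(scores, targets):
            if t_i > t_j:
                diff = s_i - s_j
                # Numerically stable form of log(1 + exp(-diff)).
                losses.append(max(-diff, 0.0) + math.log1p(math.exp(-abs(diff))))
    return sum(losses) / len(losses) if losses else 0.0


# The relevant document (target 1) is scored above the non-relevant one.
good = ranknet_loss([2.0, 0.0], [1, 0])
# The relevant document is scored below the non-relevant one.
bad = ranknet_loss([0.0, 2.0], [1, 0])
print(good < bad)
```

The loss shrinks as the relevant document's score moves further above the non-relevant document's score, which is what fine-tuning on triples encourages.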

Indexing

To index a document collection using an already fine-tuned lightning_ir.bi_encoder.model.BiEncoderModel, use the index method of the LightningIRTrainer. The trainer must receive an IndexCallback, which handles writing the document embeddings to disk. The IndexCallback is configured with an IndexConfig that specifies the type of index to use and how this index should be configured. If the selected lightning_ir.bi_encoder.model.BiEncoderModel generates sparse embeddings, a SparseIndexConfig should be used. For dense embeddings, Lightning IR provides a lightning_ir.retrieve.faiss_indexer.FaissIndexConfig that uses faiss for fast approximate nearest neighbor search. The snippet below demonstrates how to index the MS MARCO passage ranking dataset using an already fine-tuned bi-encoder.

from lightning_ir import (
    BiEncoderModule,
    DocDataset,
    FaissFlatIndexConfig,
    IndexCallback,
    LightningIRDataModule,
    LightningIRTrainer,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="webis/bert-bi-encoder",
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)

# Define the index callback
callback = IndexCallback(
    index_dir="./msmarco-passage-index",
    index_config=FaissFlatIndexConfig(),
)

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Index the data
trainer.index(module, data_module)
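
As a rough mental model of what a flat index does, the sketch below stores every document embedding as-is during indexing and scans all of them at search time. It assumes that the flat configuration in the snippet above corresponds to an exhaustive (non-approximate) faiss index; the FlatIndex class, its methods, and the toy two-dimensional embeddings are hypothetical and written in plain Python rather than with faiss.

```python
import heapq


class FlatIndex:
    """Toy exhaustive index (hypothetical, for illustration only)."""

    def __init__(self):
        self.doc_ids = []
        self.embeddings = []

    def add(self, doc_ids, embeddings):
        # Indexing step: append one batch of document embeddings.
        self.doc_ids.extend(doc_ids)
        self.embeddings.extend(embeddings)

    def search(self, query, k):
        # Search step: dot product of the query with every stored embedding.
        scores = [sum(q * d for q, d in zip(query, emb)) for emb in self.embeddings]
        return heapq.nlargest(k, zip(self.doc_ids, scores), key=lambda pair: pair[1])


index = FlatIndex()
index.add(["d1", "d2"], [[1.0, 0.0], [0.0, 1.0]])  # first batch
index.add(["d3"], [[0.7, 0.7]])  # second batch
top = index.search([1.0, 0.2], k=2)
print([doc_id for doc_id, _ in top])
```

A flat index trades speed for exactness: every query touches every embedding, which is why approximate index types are usually preferred for large collections.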

Searching

To search for relevant documents given a query using an already fine-tuned lightning_ir.bi_encoder.model.BiEncoderModel, use the search method of the LightningIRTrainer. The trainer must receive a SearchCallback, which handles loading the index and searching for relevant documents based on the generated query embeddings. The SearchCallback is configured with a SearchConfig that must match the index configuration used during indexing. To save the results to disk in the form of a run file, you can optionally add a lightning_ir.lightning_utils.callbacks.RankCallback and specify a directory to save run files to. The snippet below demonstrates how to search for relevant documents given a query using an already fine-tuned bi-encoder.

from lightning_ir import (
    BiEncoderModule,
    FaissSearchConfig,
    LightningIRDataModule,
    LightningIRTrainer,
    QueryDataset,
    SearchCallback,
)

# Define the model
module = BiEncoderModule(
    model_name_or_path="webis/bert-bi-encoder",
    evaluation_metrics=["nDCG@10"],
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[
        QueryDataset("msmarco-passage/trec-dl-2019/judged"),
        QueryDataset("msmarco-passage/trec-dl-2020/judged"),
    ],
    inference_batch_size=4,
)

# Define the search callback
callback = SearchCallback(
    index_dir="./msmarco-passage-index",
    search_config=FaissSearchConfig(k=100),
    save_dir="./runs",
)

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Retrieve relevant documents
trainer.search(module, data_module)
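
For orientation, the sketch below shows how search results could be serialized into a run file, assuming the widely used TREC run format with six whitespace-separated columns: query ID, the literal Q0, document ID, rank, score, and a run tag. Whether Lightning IR writes exactly this layout is an assumption here; the function format_trec_run, the tag value, and the sample scores are hypothetical.

```python
def format_trec_run(results, tag="bi-encoder"):
    """Turn {query_id: [(doc_id, score), ...]} into TREC-style run lines.

    Hypothetical helper for illustration; not part of Lightning IR.
    """
    lines = []
    for query_id, hits in results.items():
        # Rank documents by descending score; ranks start at 1.
        ranked = sorted(hits, key=lambda hit: hit[1], reverse=True)
        for rank, (doc_id, score) in enumerate(ranked, start=1):
            lines.append(f"{query_id} Q0 {doc_id} {rank} {score:.4f} {tag}")
    return lines


lines = format_trec_run({"q1": [("d3", 0.84), ("d1", 1.0)]})
print("\n".join(lines))
```

A run file of this kind is what a later re-ranking stage can read back in as its input.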

Re-Ranking

To re-rank a set of retrieved documents using an already fine-tuned BiEncoderModel or CrossEncoderModel (the latter is usually more effective), use the re_rank method of the LightningIRTrainer. The trainer must receive a ReRankCallback, which handles saving the re-ranked run file to disk. The snippet below demonstrates how to re-rank a set of retrieved documents using an already fine-tuned cross-encoder.

from lightning_ir import CrossEncoderModule, LightningIRDataModule, LightningIRTrainer, ReRankCallback, RunDataset

# Define the model
module = CrossEncoderModule(
    model_name_or_path="webis/monoelectra-base",
    evaluation_metrics=["nDCG@10"],
)

# Define the data module
data_module = LightningIRDataModule(
    inference_datasets=[
        RunDataset("./runs/msmarco-passage-trec-dl-2019-judged.run"),
        RunDataset("./runs/msmarco-passage-trec-dl-2020-judged.run"),
    ],
    inference_batch_size=4,
)

# Define the re-rank callback
callback = ReRankCallback(save_dir="./re-ranked-runs")

# Define the trainer
trainer = LightningIRTrainer(callbacks=[callback])

# Re-rank the retrieved documents
trainer.re_rank(module, data_module)
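
The snippets above request nDCG@10 as an evaluation metric. As a minimal sketch of what that metric measures, the function below computes nDCG at a cutoff k for a single query using linear gain and a log2 rank discount. Evaluation tools differ in details such as the gain function (some use 2^rel - 1), so this is not necessarily the exact variant Lightning IR reports; the function name ndcg_at_k and the sample relevance judgments are hypothetical.

```python
import math


def ndcg_at_k(ranked_doc_ids, qrels, k=10):
    """nDCG@k for one query; qrels maps doc_id to a graded relevance label.

    Hypothetical helper for illustration; uses linear gain.
    """

    def dcg(gains):
        # Discount the gain at rank r (1-based) by log2(r + 1).
        return sum(g / math.log2(rank + 1) for rank, g in enumerate(gains, start=1))

    gains = [qrels.get(doc_id, 0) for doc_id in ranked_doc_ids[:k]]
    ideal_dcg = dcg(sorted(qrels.values(), reverse=True)[:k])
    return dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0


qrels = {"a": 2, "b": 1}
before = ndcg_at_k(["b", "a"], qrels)  # less relevant document ranked first
after = ndcg_at_k(["a", "b"], qrels)  # ideal ordering
print(before < after)
```

Comparing the score of a first-stage run with the score of its re-ranked counterpart, as in this toy example, is the usual way to check whether re-ranking helped.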