Build A Custom Model
This section provides step-by-step guides on how to build custom BiEncoderModel and CrossEncoderModel models in Lightning IR.
Bi-Encoder
Say we wanted to build a custom bi-encoder model that applies an additional linear layer to the contextualized embeddings produced by the backbone. To make this option configurable, we first need to subclass the BiEncoderConfig and add a new attribute for the additional linear layer. We must also assign a new model_type to our model and make sure our additional attribute is included in the ADDED_ARGS set so that saved models are loaded correctly. For example:
from lightning_ir.bi_encoder.config import BiEncoderConfig


class CustomBiEncoderConfig(BiEncoderConfig):
    model_type = "custom-bi-encoder"
    ADDED_ARGS = BiEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"})

    def __init__(self, additional_linear_layer: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer
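Because the new flag is a plain config attribute, it can be toggled at construction time and travels with the config when the model is saved. A quick, purely illustrative check, relying only on the standard PretrainedConfig serialization inherited from Hugging Face:

config = CustomBiEncoderConfig(additional_linear_layer=False)
print(config.additional_linear_layer)  # False
print("additional_linear_layer" in config.to_dict())  # True -- the flag is part of the serialized config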
Next, we need to subclass the lightning_ir.bi_encoder.model.BiEncoderModel and override the lightning_ir.bi_encoder.model.BiEncoderModel.encode method to include the additional linear layer. We also need to ensure that our new config class is registered with our new model via the config_class attribute. In the encode method, the _backbone_forward() method runs the backbone model and returns the contextualized embeddings of the input sequence. We then apply our additional linear layer to these contextualized embeddings, after which the remaining steps of the bi-encoder processing pipeline (projection, sparsification, pooling, and optional normalization) are applied (see Model for more details). For example:
from typing import Literal

import torch
from transformers import BatchEncoding

from lightning_ir import BiEncoderEmbedding, BiEncoderModel


class CustomBiEncoderModel(BiEncoderModel):
    config_class = CustomBiEncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.additional_linear_layer = None
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(
                config.hidden_size, config.hidden_size
            )

    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> BiEncoderEmbedding:
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        if self.additional_linear_layer is not None:  # apply additional linear layer
            embeddings = self.additional_linear_layer(embeddings)
        if self.projection is not None:
            embeddings = self.projection(embeddings)
        embeddings = self._sparsification(embeddings, self.config.sparsification)
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(
            encoding["input_ids"],
            encoding["attention_mask"],
            expansion,
            pooling_strategy,
            mask_scoring_input_ids,
        )
        return BiEncoderEmbedding(embeddings, scoring_mask)
Finally, to make sure we can use our new model within the Hugging Face ecosystem, we need to register it with the Hugging Face auto loading mechanism. We additionally need to register the BiEncoderTokenizer to ensure it is loaded alongside our new model. We can do this by adding the following code to our model file:
from lightning_ir import BiEncoderTokenizer
from transformers import AutoConfig, AutoModel, AutoTokenizer
AutoConfig.register(CustomBiEncoderConfig.model_type, CustomBiEncoderConfig)
AutoModel.register(CustomBiEncoderConfig, CustomBiEncoderModel)
AutoTokenizer.register(CustomBiEncoderConfig, BiEncoderTokenizer)
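Once registered, generic Hugging Face code resolves the "custom-bi-encoder" model_type in a saved config.json back to our classes. A minimal sketch, assuming a model has already been saved with save_pretrained to a hypothetical models/custom-bi-encoder directory and that the registration code above has run in the same process:

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("models/custom-bi-encoder")  # hypothetical local directory
assert isinstance(config, CustomBiEncoderConfig)
model = AutoModel.from_pretrained("models/custom-bi-encoder")
assert isinstance(model, CustomBiEncoderModel)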
Now we can use our custom bi-encoder model in the same way as the built-in models. For example, to fine-tune our custom bi-encoder model on the MS MARCO dataset, we can use the following code:
from torch.optim import AdamW

from lightning_ir import (
    BiEncoderModule,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=CustomBiEncoderConfig(),  # our custom config
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(max_steps=100_000)
trainer.fit(module, data_module)
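Once trainer.fit returns, the fine-tuned module can be persisted like any other Lightning module, for example via the trainer's standard checkpointing (a minimal sketch; the path is illustrative):

trainer.save_checkpoint("checkpoints/custom-bi-encoder.ckpt")  # standard PyTorch Lightning API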
Here is the full code for our custom bi-encoder model:
custom_bi_encoder.py
from typing import Literal

import torch
from torch.optim import AdamW
from transformers import AutoConfig, AutoModel, AutoTokenizer, BatchEncoding

from lightning_ir import (
    BiEncoderConfig,
    BiEncoderEmbedding,
    BiEncoderModel,
    BiEncoderModule,
    BiEncoderTokenizer,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)


class CustomBiEncoderConfig(BiEncoderConfig):
    model_type = "custom-bi-encoder"
    ADDED_ARGS = BiEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"})

    def __init__(self, additional_linear_layer: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer


class CustomBiEncoderModel(BiEncoderModel):
    config_class = CustomBiEncoderConfig

    def __init__(self, config: CustomBiEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.additional_linear_layer = None
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None = None,
        mask_scoring_input_ids: torch.Tensor | None = None,
    ) -> BiEncoderEmbedding:
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        if self.additional_linear_layer is not None:  # apply additional linear layer
            embeddings = self.additional_linear_layer(embeddings)
        if self.projection is not None:
            embeddings = self.projection(embeddings)
        embeddings = self._sparsification(embeddings, self.config.sparsification)
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(
            encoding["input_ids"],
            encoding["attention_mask"],
            expansion,
            pooling_strategy,
            mask_scoring_input_ids,
        )
        return BiEncoderEmbedding(embeddings, scoring_mask)


# register the config, model and tokenizer with the transformers Auto* classes
AutoConfig.register(CustomBiEncoderConfig.model_type, CustomBiEncoderConfig)
AutoModel.register(CustomBiEncoderConfig, CustomBiEncoderModel)
AutoTokenizer.register(CustomBiEncoderConfig, BiEncoderTokenizer)

# fine-tune our custom model
module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=CustomBiEncoderConfig(),  # our custom config
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(max_steps=100_000)
trainer.fit(module, data_module)
Cross-Encoder
Say we wanted to build a custom cross-encoder model that adds an additional linear layer on top of the pooled embeddings. To make this option configurable, we first need to subclass the CrossEncoderConfig and add a new attribute for the additional linear layer. We must also assign a new model_type to our model and make sure our additional attribute is included in the ADDED_ARGS set so that saved models are loaded correctly. For example:
from lightning_ir import CrossEncoderConfig


class CustomCrossEncoderConfig(CrossEncoderConfig):
    model_type = "custom-cross-encoder"
    ADDED_ARGS = CrossEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"})

    def __init__(self, additional_linear_layer: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer
Next, we need to subclass the CrossEncoderModel and override the forward() method to include the additional linear layer. We also need to ensure that our new config class is registered with our new model via the config_class attribute. In the forward() method, the _backbone_forward() method runs the backbone model and returns the contextualized embeddings of the input sequence. The _pooling() method then aggregates these embeddings according to the pooling strategy specified in the config. We apply our additional linear layer to the pooled embeddings and finally use a linear layer to compute the relevance score. For example:
import torch
from transformers import BatchEncoding

from lightning_ir import CrossEncoderModel, CrossEncoderOutput


class CustomCrossEncoderModel(CrossEncoderModel):
    config_class = CustomCrossEncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.additional_linear_layer = None
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(
                config.hidden_size, config.hidden_size
            )

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self._pooling(
            embeddings,
            encoding.get("attention_mask", None),
            pooling_strategy=self.config.pooling_strategy,
        )
        if self.additional_linear_layer is not None:  # apply additional linear layer
            embeddings = self.additional_linear_layer(embeddings)
        scores = self.linear(embeddings).view(-1)
        return CrossEncoderOutput(scores=scores, embeddings=embeddings)
Finally, to make sure we can use our new model within the Hugging Face ecosystem, we need to register it with the Hugging Face auto loading mechanism. We additionally need to register the CrossEncoderTokenizer to ensure it is loaded alongside our new model. We can do this by adding the following code to our model file:
from lightning_ir import CrossEncoderTokenizer
from transformers import AutoConfig, AutoModel, AutoTokenizer
AutoConfig.register(CustomCrossEncoderConfig.model_type, CustomCrossEncoderConfig)
AutoModel.register(CustomCrossEncoderConfig, CustomCrossEncoderModel)
AutoTokenizer.register(CustomCrossEncoderConfig, CrossEncoderTokenizer)
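As a quick, illustrative check that registration took effect, the Auto classes now map the "custom-cross-encoder" model type string back to our config class (AutoConfig.for_model instantiates the config registered for a given model_type):

from transformers import AutoConfig

config = AutoConfig.for_model("custom-cross-encoder")  # resolves via the registered model_type
assert isinstance(config, CustomCrossEncoderConfig)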
Now we can use our custom cross-encoder model in the same way as the built-in models. For example, to fine-tune our custom cross-encoder model on the MS MARCO dataset, we can use the following code:
from torch.optim import AdamW

from lightning_ir import (
    CrossEncoderModule,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)

module = CrossEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=CustomCrossEncoderConfig(),  # our custom config
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(max_steps=100_000)
trainer.fit(module, data_module)
Here is the full code for our custom cross-encoder model:
custom_cross_encoder.py
import torch
from torch.optim import AdamW
from transformers import AutoConfig, AutoModel, AutoTokenizer, BatchEncoding

from lightning_ir import (
    CrossEncoderModel,
    CrossEncoderModule,
    CrossEncoderOutput,
    CrossEncoderTokenizer,
    LightningIRDataModule,
    LightningIRTrainer,
    RankNet,
    TupleDataset,
)
from lightning_ir.cross_encoder.config import CrossEncoderConfig


class CustomCrossEncoderConfig(CrossEncoderConfig):
    model_type = "custom-cross-encoder"
    ADDED_ARGS = CrossEncoderConfig.ADDED_ARGS.union({"additional_linear_layer"})

    def __init__(self, additional_linear_layer: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer


class CustomCrossEncoderModel(CrossEncoderModel):
    config_class = CustomCrossEncoderConfig

    def __init__(self, config: CustomCrossEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.additional_linear_layer = None
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self._pooling(
            embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
        )
        if self.additional_linear_layer is not None:  # apply additional linear layer
            embeddings = self.additional_linear_layer(embeddings)
        scores = self.linear(embeddings).view(-1)
        return CrossEncoderOutput(scores=scores, embeddings=embeddings)


# register the config, model and tokenizer with the transformers Auto* classes
AutoConfig.register(CustomCrossEncoderConfig.model_type, CustomCrossEncoderConfig)
AutoModel.register(CustomCrossEncoderConfig, CustomCrossEncoderModel)
AutoTokenizer.register(CustomCrossEncoderConfig, CrossEncoderTokenizer)

# fine-tune our custom model
module = CrossEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=CustomCrossEncoderConfig(),  # our custom config
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)
data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"), train_batch_size=32
)
trainer = LightningIRTrainer(max_steps=100_000)
trainer.fit(module, data_module)