CrossEncoderModel
- class lightning_ir.cross_encoder.model.CrossEncoderModel(config: CrossEncoderConfig, *args, **kwargs)
Bases: LightningIRModel
- __init__(config: CrossEncoderConfig, *args, **kwargs)
A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are aggregated into a single vector and fed to a linear layer, which computes the final relevance score.
- Parameters:
config (CrossEncoderConfig) – Configuration for the cross-encoder model
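The aggregation step described above can be sketched as follows. This is a minimal sketch under stated assumptions, not lightning_ir's implementation: NumPy arrays stand in for the backbone's contextualized embeddings, and the helper names `pool` and `score` as well as the strategy names `"first"` (take the first token, e.g. [CLS]) and `"mean"` (average over non-padding tokens) are hypothetical choices made for illustration.

```python
import numpy as np


def pool(embeddings: np.ndarray, attention_mask: np.ndarray, strategy: str = "first") -> np.ndarray:
    """Aggregate (batch, seq_len, hidden) embeddings into one (batch, hidden) vector per sequence.

    Hypothetical helper; the strategy names are illustrative assumptions.
    """
    if strategy == "first":
        # Use the embedding of the first token (e.g. [CLS] in BERT-style backbones).
        return embeddings[:, 0]
    if strategy == "mean":
        # Average only over real tokens; padding positions have mask value 0.
        mask = attention_mask[..., None].astype(embeddings.dtype)
        return (embeddings * mask).sum(axis=1) / mask.sum(axis=1)
    raise ValueError(f"unknown pooling strategy: {strategy!r}")


def score(embeddings: np.ndarray, attention_mask: np.ndarray, weight: np.ndarray,
          bias: float = 0.0, strategy: str = "first") -> np.ndarray:
    """Pool each joint query-document sequence, then apply a linear layer -> (batch,) scores."""
    pooled = pool(embeddings, attention_mask, strategy)
    return pooled @ weight + bias


# One joint sequence with three tokens; the last token is padding.
embeddings = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = np.array([[1, 1, 0]])
weight = np.array([1.0, 1.0])
print(score(embeddings, attention_mask, weight))                   # [3.]
print(score(embeddings, attention_mask, weight, strategy="mean"))  # [5.]
```

Masking in the mean branch keeps padding tokens from distorting the pooled vector, which is why the large values in the padded position do not affect either score.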
Methods
__init__(config, *args, **kwargs) – A cross-encoder model that jointly encodes a query and document(s).
forward(encoding) – Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
from_pretrained(model_name_or_path, *args, ...) – Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel.
Attributes
ALLOW_SUB_BATCHING – Flag to allow mini-batches of documents for a single query.
- config_class
Configuration class for cross-encoder models.
alias of CrossEncoderConfig
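To make the idea behind ALLOW_SUB_BATCHING concrete: because a cross-encoder scores each query-document pair independently, scoring a query's candidate documents in mini-batches yields the same scores as scoring them all at once, while bounding how many pairs pass through the model per call. The sketch below shows only that idea; `score_in_sub_batches` and `score_fn` are hypothetical names, and it does not reflect how lightning_ir actually reads the flag.

```python
from typing import Callable, List, Sequence


def score_in_sub_batches(
    score_fn: Callable[[str, Sequence[str]], List[float]],
    query: str,
    docs: Sequence[str],
    sub_batch_size: int,
) -> List[float]:
    """Score the documents of one query in chunks of at most sub_batch_size documents."""
    if sub_batch_size < 1:
        raise ValueError("sub_batch_size must be positive")
    scores: List[float] = []
    for start in range(0, len(docs), sub_batch_size):
        # Each call only sees a slice of the documents, bounding memory per forward pass.
        scores.extend(score_fn(query, docs[start:start + sub_batch_size]))
    return scores


# Toy scoring function (hypothetical) that records the size of each mini-batch.
calls: List[int] = []


def toy_score(query: str, docs: Sequence[str]) -> List[float]:
    calls.append(len(docs))
    return [float(len(doc)) for doc in docs]


docs = ["a", "bb", "ccc", "dddd", "eeeee"]
print(score_in_sub_batches(toy_score, "q", docs, sub_batch_size=2))  # [1.0, 2.0, 3.0, 4.0, 5.0]
print(calls)  # [2, 2, 1]
```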
- forward(encoding: BatchEncoding) → CrossEncoderOutput
Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
- Parameters:
encoding (BatchEncoding) – Tokenizer encoding for the joint query-document input sequence
- Returns:
Output of the model
- Return type:
CrossEncoderOutput
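For a BERT-style backbone, the joint query-document input that forward receives through encoding typically has the layout `[CLS] query [SEP] document [SEP]`, with token type IDs marking which segment each token belongs to; a Hugging Face tokenizer called with a (query, document) pair produces this layout. The hand-rolled sketch below only illustrates that layout on pre-split tokens. `build_joint_input` is a hypothetical helper, and other backbones use different special tokens.

```python
from typing import Dict, List


def build_joint_input(
    query_tokens: List[str],
    doc_tokens: List[str],
    cls_token: str = "[CLS]",
    sep_token: str = "[SEP]",
) -> Dict[str, List]:
    """Concatenate a query and a document into one BERT-style input sequence."""
    tokens = [cls_token, *query_tokens, sep_token, *doc_tokens, sep_token]
    # Segment 0: [CLS], the query, and the first [SEP]; segment 1: the document and the final [SEP].
    token_type_ids = [0] * (len(query_tokens) + 2) + [1] * (len(doc_tokens) + 1)
    # No padding here, so every position is attended to.
    attention_mask = [1] * len(tokens)
    return {"tokens": tokens, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


encoding = build_joint_input(["what", "is", "ir"], ["information", "retrieval"])
print(encoding["tokens"])
# ['[CLS]', 'what', 'is', 'ir', '[SEP]', 'information', 'retrieval', '[SEP]']
print(encoding["token_type_ids"])  # [0, 0, 0, 0, 0, 1, 1, 1]
```

Because query and document share one sequence, every document token can attend to every query token, which is what distinguishes a cross-encoder from encoding the two texts separately.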
- classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) → LightningIRModel
Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model
- Raises:
ValueError – If called on the abstract class LightningIRModel and no config is passed
- Returns:
A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
- Return type:
LightningIRModel
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
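The doctest shows from_pretrained returning an instance of a class named CrossEncoderBertModel, i.e. a class composed of the cross-encoder mixin and the backbone. The sketch below mimics that composition with plain Python stand-in classes and the built-in `type()`. Every class and function in it (`BertModel`, `MIXINS`, `derive_model_class`, `load`) is a hypothetical stub, not LightningIRModelClassFactory's actual implementation; real loading would also infer the backbone from the checkpoint and load its weights.

```python
class LightningIRModel:
    """Stand-in for the abstract base class (hypothetical stub)."""
    config_class = None


class CrossEncoderConfig:
    """Stand-in for the cross-encoder configuration (hypothetical stub)."""


class CrossEncoderModel(LightningIRModel):
    """Stand-in for the cross-encoder mixin (hypothetical stub)."""
    config_class = CrossEncoderConfig


class BertModel:
    """Stand-in for a transformers backbone class (hypothetical stub)."""

    def __init__(self, config):
        self.config = config


# Hypothetical registry mapping a config type to the mixin that implements it.
MIXINS = {CrossEncoderConfig: CrossEncoderModel}


def derive_model_class(mixin: type, backbone: type) -> type:
    """Create a class inheriting from the mixin first, then the backbone."""
    name = mixin.__name__.removesuffix("Model") + backbone.__name__
    return type(name, (mixin, backbone), {})


def load(cls: type, backbone: type, config=None):
    """Mimic the dispatch described for from_pretrained (no weights are loaded)."""
    if config is None:
        if cls.config_class is None:
            # Mirrors the documented ValueError for the abstract base class without a config.
            raise ValueError(f"{cls.__name__} is abstract; pass a config")
        config = cls.config_class()
    derived = derive_model_class(MIXINS[type(config)], backbone)
    return derived(config)


print(type(load(CrossEncoderModel, BertModel)).__name__)  # CrossEncoderBertModel
print(type(load(LightningIRModel, BertModel, config=CrossEncoderConfig())).__name__)  # CrossEncoderBertModel
```

Listing the mixin before the backbone in the bases puts the mixin's methods ahead of the backbone's in the method resolution order, so the cross-encoder behavior takes precedence wherever both define the same name.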