Source code for lightning_ir.base.config
"""
Base configuration class for Lightning IR models.

This module defines the configuration class `LightningIRConfig` which is used to instantiate
a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
class from the Hugging Face Transformers library.
"""

from pathlib import Path
from typing import Any, Dict, Set

from transformers import PretrainedConfig

from .class_factory import LightningIRConfigClassFactory
from .external_model_hub import CHECKPOINT_MAPPING


class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    TOKENIZER_ARGS: Set[str] = {"query_length", "doc_length"}
    """Arguments for the tokenizer."""
    ADDED_ARGS: Set[str] = TOKENIZER_ARGS
    """Arguments added to the configuration."""

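    # Illustrative sketch (not part of the module): derived configurations are
    # expected to extend these sets with their own arguments so that
    # ``to_added_args_dict`` and ``to_tokenizer_dict`` pick them up. The
    # subclass and its ``pooling`` argument below are hypothetical.
    #
    #   class MyRankerConfig(LightningIRConfig):
    #       model_type = "my-ranker"
    #       ADDED_ARGS = LightningIRConfig.ADDED_ARGS.union({"pooling"})
    #
    #       def __init__(self, pooling: str = "mean", **kwargs):
    #           super().__init__(**kwargs)
    #           self.pooling = pooling
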
    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

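    # Usage sketch: instantiating the base configuration; any further keyword
    # arguments are forwarded to ``transformers.PretrainedConfig``.
    #
    #   config = LightningIRConfig(query_length=16, doc_length=256)
    #   assert (config.query_length, config.doc_length) == (16, 256)
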
    def to_added_args_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the added arguments.

        :return: Added arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.ADDED_ARGS if hasattr(self, arg)}

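    # Usage sketch: collecting the Lightning-IR-specific arguments, e.g. to
    # carry them over into another configuration.
    #
    #   config = LightningIRConfig(query_length=24, doc_length=256)
    #   config.to_added_args_dict()  # {"query_length": 24, "doc_length": 256}
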
    def to_tokenizer_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the tokenizer arguments.

        :return: Tokenizer arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.TOKENIZER_ARGS}

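    # Usage sketch: forwarding the tokenizer arguments when constructing a
    # matching tokenizer (assuming a tokenizer class that accepts them, such as
    # Lightning IR's ``LightningIRTokenizer``; the checkpoint name is a
    # placeholder).
    #
    #   config = LightningIRConfig(query_length=32, doc_length=512)
    #   tokenizer = LightningIRTokenizer.from_pretrained(
    #       "bert-base-uncased", **config.to_tokenizer_dict()
    #   )
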
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        output = getattr(super(), "to_dict")()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

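    # Usage sketch: the serialized dictionary carries the added arguments (via
    # the parent ``to_dict``) plus ``backbone_model_type`` once a backbone has
    # been attached by the class factory.
    #
    #   d = LightningIRConfig(query_length=32, doc_length=512).to_dict()
    #   d["query_length"], d["doc_length"]  # (32, 512)
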
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the
        transformers.PretrainedConfig.from_pretrained_ method.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Pretrained model name or path
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If `pretrained_model_name_or_path` is not a Lightning IR model and no
            :py:class:`LightningIRConfig` is passed
        :return: Derived LightningIRConfig instance
        :rtype: LightningIRConfig
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # no backbone config found, create derived lightning-ir config based on backbone config
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                config_class = config.__class__
            elif cls is not LightningIRConfig:
                config_class = cls
            else:
                config_class = LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(config_class).from_backbone_class(BackboneConfig)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                return derived_config
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
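
# Usage sketch (illustrative, not part of the module): loading a configuration
# through a derived config class. ``BiEncoderConfig`` is one of the Lightning IR
# configuration subclasses; the checkpoint name is a placeholder backbone.
#
#   from lightning_ir import BiEncoderConfig
#
#   config = BiEncoderConfig.from_pretrained("bert-base-uncased", query_length=16)
#   # ``config`` is an instance of a class dynamically derived from
#   # ``BiEncoderConfig`` and the backbone's configuration class.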