Source code for lightning_ir.base.config

  1"""
  2Base configuration class for Lightning IR models.
  3
  4This module defines the configuration class `LightningIRConfig` which is used to instantiate
  5a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
  6class from the Hugging Face Transformers library.
  7"""
  8
  9from pathlib import Path
 10from typing import Any, Dict, Set
 11
 12from transformers import PretrainedConfig
 13
 14from .class_factory import LightningIRConfigClassFactory
 15from .external_model_hub import CHECKPOINT_MAPPING
 16
 17
class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    TOKENIZER_ARGS: Set[str] = {"query_length", "doc_length"}
    """Arguments for the tokenizer."""
    ADDED_ARGS: Set[str] = TOKENIZER_ARGS
    """Arguments added to the configuration."""
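
    # Note (added for illustration; not in the original module): concrete
    # Lightning IR configs are produced by LightningIRConfigClassFactory,
    # which mixes this class into a backbone config class (e.g. BertConfig)
    # and sets `backbone_model_type` on the derived class; see
    # `from_pretrained` below.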

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

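    # Illustrative usage (added; not in the original module): the added
    # arguments become plain attributes on the config, next to everything
    # transformers.PretrainedConfig already stores:
    #
    #     config = LightningIRConfig(query_length=16, doc_length=256)
    #     assert config.query_length == 16 and config.doc_length == 256
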
    def to_added_args_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the added arguments.

        :return: Added arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.ADDED_ARGS if hasattr(self, arg)}

    def to_tokenizer_dict(self) -> Dict[str, Any]:
        """Outputs a dictionary of the tokenizer arguments.

        :return: Tokenizer arguments
        :rtype: Dict[str, Any]
        """
        return {arg: getattr(self, arg) for arg in self.TOKENIZER_ARGS}

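    # Illustrative sketch (added; not in the original module): the two helpers
    # above differ only in that `to_added_args_dict` skips arguments that were
    # never set on the instance, while `to_tokenizer_dict` assumes every entry
    # of TOKENIZER_ARGS exists:
    #
    #     config = LightningIRConfig(query_length=16, doc_length=256)
    #     config.to_tokenizer_dict()   # {"query_length": 16, "doc_length": 256}
    #     config.to_added_args_dict()  # same keys, since ADDED_ARGS == TOKENIZER_ARGS
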
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the backbone
        model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        output = getattr(super(), "to_dict")()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

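    # Illustrative sketch (added; not in the original module): for a derived
    # config whose class factory has set `backbone_model_type`, the serialized
    # dictionary carries it next to the usual PretrainedConfig fields, so the
    # backbone can be recovered when the config is loaded again, e.g.:
    #
    #     config.to_dict()
    #     # {..., "model_type": "lightning-ir", "backbone_model_type": "bert", ...}
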
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the
        transformers.PretrainedConfig.from_pretrained_ method.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Pretrained model name or path
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If `pretrained_model_name_or_path` is not a Lightning IR model and no
            :py:class:`LightningIRConfig` is passed
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # no backbone config found, create derived lightning-ir config based on backbone config
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                config_class = config.__class__
            elif cls is not LightningIRConfig:
                config_class = cls
            else:
                config_class = LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(config_class).from_backbone_class(BackboneConfig)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
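
# Illustrative usage (added; not in the original module). `from_pretrained`
# derives a config class that mixes LightningIRConfig into the checkpoint's
# backbone config class, so the returned object is both a LightningIRConfig
# and, e.g., a BertConfig. The checkpoint name below is hypothetical:
#
#     config = LightningIRConfig.from_pretrained("org/some-bi-encoder")
#     isinstance(config, LightningIRConfig)  # True
#     config.query_length, config.doc_length  # added arguments restored from the checkpoint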