Source code for lightning_ir.base.class_factory

  1"""
  2Class factory module for Lightning IR.
  3
  4This module provides factory classes for creating various components of the Lightning IR library
  5by extending Hugging Face Transformers classes.
  6"""
  7
  8from __future__ import annotations
  9
 10from abc import ABC, abstractmethod
 11from pathlib import Path
 12from typing import TYPE_CHECKING, Any, Tuple, Type
 13
 14from transformers import (
 15    CONFIG_MAPPING,
 16    MODEL_MAPPING,
 17    TOKENIZER_MAPPING,
 18    PretrainedConfig,
 19    PreTrainedModel,
 20    PreTrainedTokenizerBase,
 21)
 22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
 23
 24if TYPE_CHECKING:
 25    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
 26
 27
class LightningIRClassFactory(ABC):
    """Abstract base for factories that derive Lightning IR classes from HuggingFace classes.

    Subclasses implement :meth:`from_pretrained` and :meth:`from_backbone_class`. The static helpers defined here read
    the ``model_type`` and ``backbone_model_type`` entries of a checkpoint configuration.
    """

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Stores the Lightning IR configuration mixin the factory works with.

        If the passed class carries a non-``None`` ``backbone_model_type`` attribute, its first base class is stored
        instead of the class itself.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        is_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if is_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Looks up the backbone configuration class for a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class registered in ``CONFIG_MAPPING`` for the backbone model type
        :rtype: Type[PretrainedConfig]
        """
        return CONFIG_MAPPING[LightningIRClassFactory.get_backbone_model_type(model_name_or_path)]

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> Type[LightningIRConfig] | None:
        """Looks up the Lightning IR configuration class for a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class registered in ``CONFIG_MAPPING`` for the Lightning IR model type, or ``None`` if
            no Lightning IR model type could be determined
        :rtype: Type[LightningIRConfig] | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        return None if lir_model_type is None else CONFIG_MAPPING[lir_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from the configuration of a pretrained checkpoint.

        The ``backbone_model_type`` entry takes precedence; ``model_type`` is used when it is missing or falsy.
        Additional arguments are forwarded to ``PretrainedConfig.get_config_dict``.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If neither entry yields a model type
        :return: Model type of the backbone model
        :rtype: str
        """
        raw_config, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        model_type = raw_config.get("backbone_model_type", None)
        if not model_type:
            model_type = raw_config.get("model_type")
        if model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from the configuration of a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: The ``model_type`` entry if the configuration has a ``backbone_model_type`` key, otherwise ``None``
        :rtype: str | None
        """
        raw_config, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" in raw_config:
            return raw_config.get("model_type", None)
        return None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model (e.g., ``bi-encoder`` becomes ``BiEncoder``)."""
        hyphen_parts = self.MixinConfig.model_type.split("-")
        return "".join(map(str.title, hyphen_parts))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived Lightning IR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived Lightning IR class
        :rtype: Type
        """
        ...
122 123
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Factory that builds derived LightningIRConfig classes on top of HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Builds the derived LightningIRConfig class for a pretrained HuggingFace checkpoint.

        The additional positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        return self.from_backbone_class(self.get_backbone_config(model_name_or_path))

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Combines the Lightning IR configuration mixin with a transformers.PretrainedConfig_ backbone class.

        A backbone class whose ``backbone_model_type`` attribute is not ``None`` is treated as already derived and is
        returned unchanged.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass

        lir_model_type = self.MixinConfig.model_type
        mixin_config_cls: Type[LightningIRConfig] = CONFIG_MAPPING[lir_model_type]

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        # Class attributes recording both the Lightning IR model type and the wrapped backbone model type.
        namespace = {
            "model_type": lir_model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        return type(derived_name, (mixin_config_cls, BackboneClass), namespace)
165 166
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Factory that builds derived LightningIRModel classes on top of HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Builds the derived LightningIRModel class for a pretrained HuggingFace checkpoint.

        The additional positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        backbone_config_cls = self.get_backbone_config(model_name_or_path)
        return self.from_backbone_class(MODEL_MAPPING[backbone_config_cls])

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Combines the Lightning IR model mixin with a transformers.PreTrainedModel_ backbone class.

        A backbone class whose ``config_class`` has a non-``None`` ``backbone_model_type`` is treated as already
        derived and is returned unchanged. Otherwise the derived class gets a derived LightningIRConfig as its
        ``config_class`` and keeps the backbone's ``forward`` under ``_backbone_forward``.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model is not a valid backbone model (its ``config_class`` is ``None``).
        :return: The derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        BackboneConfig = BackboneClass.config_class
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        mixin_model_cls: Type[LightningIRModel] = MODEL_MAPPING[self.MixinConfig]
        config_factory = LightningIRConfigClassFactory(self.MixinConfig)
        derived_config_cls = config_factory.from_backbone_class(BackboneConfig)

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attrs = {"config_class": derived_config_cls, "_backbone_forward": BackboneClass.forward}
        return type(derived_name, (mixin_model_cls, BackboneClass), class_attrs)
216 217
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Grabs the configuration class from a checkpoint of a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone tokenizer
        :rtype: Type[PretrainedConfig]
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is consulted first. If that fails with an ``OSError`` or ``ValueError``, the model type
        is guessed from the ``backbone_tokenizer_class`` entry of the tokenizer configuration by searching
        ``TOKENIZER_MAPPING`` for a configuration class registered with that tokenizer class.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If no backbone model type can be determined from either configuration
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError) as err:
            # best guess at model type
            config_dict = get_tokenizer_config(model_name_or_path)
            backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
            if backbone_tokenizer_class is not None:
                Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
                # Mapping entries may hold ``None`` for a missing slow or fast tokenizer, so an unresolved class
                # (``None``) must not be searched for: ``None in tokenizers`` would match an unrelated model type.
                if Tokenizer is not None:
                    for config, tokenizers in TOKENIZER_MAPPING.items():
                        if Tokenizer in tokenizers:
                            return getattr(config, "model_type")
            raise ValueError("No backbone model found in the configuration") from err

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        The additional positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        BackboneConfig = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        # Index 0 holds the slow tokenizer class, index 1 the fast one; either may be None.
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        ``None`` entries are passed through unchanged. If a fast tokenizer class is present, its
        ``slow_tokenizer_class`` attribute is set to the derived slow tokenizer class.

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class (currently not used), defaults to None
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer class already has a ``config_class`` attribute, it is treated as a LightningIRTokenizer
        and returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The mixin is taken from the first slot of the mapping entry registered for the Lightning IR config.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer