1"""
2Class factory module for Lightning IR.
3
4This module provides factory classes for creating various components of the Lightning IR library
5by extending Hugging Face Transformers classes.
6"""
7
8from __future__ import annotations
9
10from abc import ABC, abstractmethod
11from pathlib import Path
12from typing import TYPE_CHECKING, Any, Tuple, Type
13
14from transformers import (
15 CONFIG_MAPPING,
16 MODEL_MAPPING,
17 TOKENIZER_MAPPING,
18 PretrainedConfig,
19 PreTrainedModel,
20 PreTrainedTokenizerBase,
21)
22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
23
24if TYPE_CHECKING:
25 from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
26
27
[docs]
class LightningIRClassFactory(ABC):
    """Abstract factory that derives Lightning IR classes from Hugging Face classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Sets up the factory for a given LightningIRConfig mixin class.

        If the passed class already defines a non-None ``backbone_model_type``, its first base class is stored as the
        mixin instead of the class itself.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        already_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if already_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig]:
        """Looks up the configuration class registered for the backbone model type of a checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone model
        :rtype: Type[PretrainedConfig]
        """
        return CONFIG_MAPPING[LightningIRClassFactory.get_backbone_model_type(model_name_or_path)]

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> Type[LightningIRConfig] | None:
        """Looks up the Lightning IR configuration class registered for the model type of a checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the Lightning IR model, or None if no Lightning IR model type is found
        :rtype: Type[LightningIRConfig] | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        return None if lir_model_type is None else CONFIG_MAPPING[lir_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from the configuration of a checkpoint.

        The ``backbone_model_type`` entry takes precedence; ``model_type`` is used when it is missing or empty.
        Additional positional and keyword arguments are forwarded to ``PretrainedConfig.get_config_dict``.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the configuration contains neither entry
        :return: Model type of the backbone model
        :rtype: str
        """
        raw_config, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        resolved_type = raw_config.get("backbone_model_type", None) or raw_config.get("model_type")
        if resolved_type is not None:
            return resolved_type
        raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from the configuration of a checkpoint.

        Only configurations that contain a ``backbone_model_type`` entry are treated as Lightning IR checkpoints.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the Lightning IR model, or None if the checkpoint has no ``backbone_model_type``
        :rtype: str | None
        """
        raw_config, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" in raw_config:
            return raw_config.get("model_type", None)
        return None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model."""
        # Title-case every dash-separated part and glue the parts together, e.g. "bi-encoder" -> "BiEncoder".
        dash_parts = self.MixinConfig.model_type.split("-")
        return "".join(map(str.title, dash_parts))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived Lightning IR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived Lightning IR class
        :rtype: Type
        """
        ...
122
123
[docs]
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Builds a derived LightningIRConfig class for the backbone of a pretrained HuggingFace model.

        Extra positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        return self.from_backbone_class(self.get_backbone_config(model_name_or_path))

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Builds a derived LightningIRConfig class from a transformers.PretrainedConfig_ backbone configuration
        class. A backbone configuration class that is already a derived LightningIRConfig (it defines a
        ``backbone_model_type``) is returned unchanged.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        if getattr(BackboneClass, "backbone_model_type", None) is not None:
            return BackboneClass

        # The mixin class is resolved through the config registry using the mixin's own model type.
        mixin_class: Type[LightningIRConfig] = CONFIG_MAPPING[self.MixinConfig.model_type]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "model_type": self.MixinConfig.model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        # The mixin comes first in the bases so that it is also first in the method resolution order.
        return type(derived_name, (mixin_class, BackboneClass), class_attributes)
165
166
[docs]
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Builds a derived LightningIRModel class for the backbone of a pretrained HuggingFace model.

        Extra positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        return self.from_backbone_class(MODEL_MAPPING[backbone_config])

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Builds a derived LightningIRModel class from a transformers.PreTrainedModel_ backbone model class. A
        backbone whose ``config_class`` already defines a ``backbone_model_type`` is returned unchanged.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model is not a valid backbone model, i.e., its ``config_class`` is None.
        :return: The derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        backbone_config = BackboneClass.config_class
        if getattr(backbone_config, "backbone_model_type", None) is not None:
            return BackboneClass
        if backbone_config is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        # The model mixin is resolved through the model registry, keyed by the mixin configuration class.
        mixin_class: Type[LightningIRModel] = MODEL_MAPPING[self.MixinConfig]
        derived_config = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(backbone_config)

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "config_class": derived_config,
            # Stores the backbone's own forward function on the derived class under a separate name.
            "_backbone_forward": BackboneClass.forward,
        }
        return type(derived_name, (mixin_class, BackboneClass), class_attributes)
216
217
[docs]
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the configuration class registered for the backbone model type of a pretrained HuggingFace tokenizer.

        Unlike the base class implementation, the model type is resolved with the tokenizer-aware
        :meth:`get_backbone_model_type` of this class.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration class of the backbone tokenizer
        :rtype: PretrainedConfig
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type]

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If it cannot be loaded, the model type is guessed from the
        ``backbone_tokenizer_class`` entry of the tokenizer configuration by searching the tokenizer registry.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If neither the model configuration nor the tokenizer configuration yields a model type
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError) as err:
            # best guess at model type
            config_dict = get_tokenizer_config(model_name_or_path)
            backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
            if backbone_tokenizer_class is not None:
                Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
                # An unresolved class name comes back as None. Registry entries may also hold None for a missing
                # slow or fast tokenizer, so searching for None would match an unrelated model type.
                if Tokenizer is not None:
                    for config, tokenizers in TOKENIZER_MAPPING.items():
                        if Tokenizer in tokenizers:
                            return config.model_type
            raise ValueError("No backbone model found in the configuration") from err

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        Extra positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        BackboneConfig = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        # Index 0 holds the slow tokenizer class, index 1 the fast one.
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        Entries that are None stay None. If a fast tokenizer is derived, its ``slow_tokenizer_class`` attribute is
        set to the derived slow tokenizer (which may be None).

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class, currently unused, defaults to None
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer already has a ``config_class`` attribute, it is treated as a LightningIRTokenizer and
        returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The first registry entry for the mixin configuration is used as the mixin for both slow and fast
        # backbone tokenizers.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer
323
324 return DerivedLightningIRTokenizer