1"""
2Model module for Lightning IR.
3
4This module contains the main model class and output class for the Lightning IR library.
5"""
6
7from collections import defaultdict
8from dataclasses import dataclass
9from functools import partial, wraps
10from pathlib import Path
11from typing import Any, Literal, Mapping, Protocol, Sequence, Type, TypeVar
12
13import torch
14from transformers import MODEL_MAPPING, BatchEncoding, BertModel
15from transformers.modeling_outputs import ModelOutput
16
17from .._flash import FLASH_ATTENTION_MAP
18from .class_factory import LightningIRModelClassFactory
19from .config import LightningIRConfig
20from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
21
22
@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Output relevance scores for query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
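
    A minimal usage sketch (illustrative only):

    .. code-block:: python

        >>> import torch
        >>> output = LightningIROutput(scores=torch.tensor([0.7, 0.3]))
        >>> output.scores
        tensor([0.7000, 0.3000])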
31 """
32
33 scores: torch.Tensor | None = None
34
35
class LightningIRModel:
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to ``False`` for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

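        # Monkey-patch the self-attention modules of the backbone with a registered
        # flash-attention forward implementation, when one exists for this backbone type.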
        if self.config.backbone_model_type is not None:
            flash_attn = FLASH_ATTENTION_MAP.get(self.config.backbone_model_type, None)
            if flash_attn is not None:
                flash_attn_forward, self_attn_pattern = flash_attn
                for name, module in self.named_modules():
                    if name.endswith(self_attn_pattern):
                        module.forward = partial(flash_attn_forward, module)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. This method is overridden by
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        :raises NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def _sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal['relu', 'relu_log'] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
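
        Illustrative sketch of the two strategies (not part of the public API):

        .. code-block:: python

            >>> import torch
            >>> embeddings = torch.tensor([[-1.0, 0.5, 2.0]])
            >>> torch.relu(embeddings)  # "relu" strategy
            tensor([[0.0000, 0.5000, 2.0000]])
            >>> torch.log1p(torch.relu(embeddings))  # "relu_log" strategy
            tensor([[0.0000, 0.4055, 1.0986]])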
95 """
96 if sparsification_strategy is None:
97 return embeddings
98 if sparsification_strategy == "relu":
99 return torch.relu(embeddings)
100 if sparsification_strategy == "relu_log":
101 return torch.log1p(torch.relu(embeddings))
102 raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")
103
    def _pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
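
        As an illustrative sketch, masked mean pooling is equivalent to:

        .. code-block:: python

            >>> import torch
            >>> embeddings = torch.arange(6.0).reshape(1, 3, 2)  # (batch, seq_len, emb_dim)
            >>> attention_mask = torch.tensor([[1, 1, 0]])
            >>> masked = embeddings * attention_mask.unsqueeze(-1)
            >>> masked.sum(dim=1, keepdim=True) / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            tensor([[[1., 2.]]])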
121 """
122 if pooling_strategy is None:
123 return embeddings
124 if pooling_strategy == "first":
125 return embeddings[:, [0]]
126 if pooling_strategy in ("sum", "mean"):
127 if attention_mask is not None:
128 embeddings = embeddings * attention_mask.unsqueeze(-1)
129 embeddings = embeddings.sum(dim=1, keepdim=True)
130 if pooling_strategy == "mean":
131 if attention_mask is not None:
132 embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
133 return embeddings
134 if pooling_strategy == "max":
135 if attention_mask is not None:
136 embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), -1e9)
137 return embeddings.max(dim=1, keepdim=True).values
138 raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
139
    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
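        # Remap state dict keys of known external (non-native) checkpoints to the
        # Lightning IR naming scheme before delegating to the transformers loading logic.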
        if pretrained_model_name_or_path in STATE_DICT_KEY_MAPPING:
            map_keys = STATE_DICT_KEY_MAPPING[pretrained_model_name_or_path]
            for orig_key, new_key in map_keys:
                if orig_key is not None:
                    state_dict[new_key] = state_dict.pop(orig_key)
                    loaded_keys[loaded_keys.index(orig_key)] = new_key
                else:
                    loaded_keys.append(new_key)
        model, *out = super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )
        if pretrained_model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[pretrained_model_name_or_path](model)
        return (model, *out)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRModel":
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                config_class = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                config_class = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                config_class = cls.config_class
            else:
                config_class = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRModelClassFactory.get_backbone_config(model_name_or_path)
            BackboneModel = MODEL_MAPPING[BackboneConfig]
            cls = LightningIRModelClassFactory(config_class).from_backbone_class(BackboneModel)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                kwargs["config"] = derived_config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        return super(LightningIRModel, cls).from_pretrained(model_name_or_path, *args, **kwargs)


214T = TypeVar("T")
215
216
217def _cat_outputs(
218 outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
219) -> torch.Tensor | T | None:
220 """Helper method to concatenate outputs of the model."""
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)


class BatchEncodingWrapper(Protocol):
    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Repeatedly halves the batch size of the input
    batch encoding if the model runs out of memory.

    :param func: Function to wrap that takes a batch encoding
    :type func: BatchEncodingWrapper
    :raises e: If CUDA runs out of memory even after lowering the batch size to 1
    :raises ValueError: If no output was generated
    :return: Wrapped function
    :rtype: BatchEncodingWrapper
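
    A hypothetical usage sketch (``MyBiEncoderModel`` is not part of the library):

    .. code-block:: python

        class MyBiEncoderModel(LightningIRModel, BertModel):
            @batch_encoding_wrapper
            def encode(self, encoding: BatchEncoding) -> torch.Tensor:
                return self._backbone_forward(**encoding).last_hidden_state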
253 """
254
255 @wraps(func)
256 def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
257 if not self.ALLOW_SUB_BATCHING:
258 return func(self, encoding, *args, **kwargs)
259 sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
260 sub_encoding = encoding
261 remaining_encoding = encoding
262 OutputClass = None
263 outputs = []
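        # Process the encoding in sub-batches; on a CUDA out-of-memory error the
        # sub-batch size is halved and the remaining encoding is retried.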
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper