Source code for lightning_ir.base.model

  1"""
  2Model module for Lightning IR.
  3
  4This module contains the main model class and output class for the Lightning IR library.
  5"""
  6
  7from collections import defaultdict
  8from dataclasses import dataclass
  9from functools import partial, wraps
 10from pathlib import Path
 11from typing import Any, Literal, Mapping, Protocol, Sequence, Type, TypeVar
 12
 13import torch
 14from transformers import MODEL_MAPPING, BatchEncoding, BertModel
 15from transformers.modeling_outputs import ModelOutput
 16
 17from .._flash import FLASH_ATTENTION_MAP
 18from .class_factory import LightningIRModelClassFactory
 19from .config import LightningIRConfig
 20from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
 21
 22
@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Output relevance scores for query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
    """

    scores: torch.Tensor | None = None
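
# Illustrative usage sketch (not part of the library source): LightningIROutput is a
# ModelOutput, so the relevance scores can be read by attribute or by key.
#
#     output = LightningIROutput(scores=torch.tensor([0.9, 0.2]))
#     output.scores          # tensor([0.9000, 0.2000])
#     output["scores"]       # the same tensor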


class LightningIRModel:
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration class for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

        if self.config.backbone_model_type is not None:
            flash_attn = FLASH_ATTENTION_MAP.get(self.config.backbone_model_type, None)
            if flash_attn is not None:
                flash_attn_forward, self_attn_pattern = flash_attn
                for name, module in self.named_modules():
                    if name.endswith(self_attn_pattern):
                        module.forward = partial(flash_attn_forward, module)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. It is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        :raises NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def _sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal['relu', 'relu_log'] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")

    def _pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings[:, [0]]
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), -1e9)
            return embeddings.max(dim=1, keepdim=True).values
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

    @classmethod
    def _load_pretrained_model(
        cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
    ):
        if pretrained_model_name_or_path in STATE_DICT_KEY_MAPPING:
            map_keys = STATE_DICT_KEY_MAPPING[pretrained_model_name_or_path]
            for orig_key, new_key in map_keys:
                if orig_key is not None:
                    state_dict[new_key] = state_dict.pop(orig_key)
                    loaded_keys[loaded_keys.index(orig_key)] = new_key
                else:
                    loaded_keys.append(new_key)
        model, *out = super()._load_pretrained_model(
            model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, *args, **kwargs
        )
        if pretrained_model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[pretrained_model_name_or_path](model)
        return (model, *out)
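
    # Illustrative example (not part of the library source): for token embeddings of
    # shape (batch_size, seq_len, emb_dim), each pooling strategy returns a single
    # vector per sequence, i.e. shape (batch_size, 1, emb_dim). With "mean" pooling,
    # padding positions are masked out before averaging. Here `model` stands for any
    # instantiated LightningIRModel subclass.
    #
    #     embeddings = torch.randn(2, 4, 8)
    #     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    #     pooled = model._pooling(embeddings, attention_mask, "mean")  # shape (2, 1, 8)
    #     sparse = model._sparsification(embeddings, "relu_log")       # log1p(relu(x))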

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRModel":
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                config_class = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                config_class = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                config_class = cls.config_class
            else:
                config_class = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if config_class is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            BackboneConfig = LightningIRModelClassFactory.get_backbone_config(model_name_or_path)
            BackboneModel = MODEL_MAPPING[BackboneConfig]
            cls = LightningIRModelClassFactory(config_class).from_backbone_class(BackboneModel)
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                kwargs["config"] = derived_config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        return super(LightningIRModel, cls).from_pretrained(model_name_or_path, *args, **kwargs)


T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper function to concatenate outputs of the model."""
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)
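

# Illustrative example (not part of the library source): _cat_outputs merges the
# sub-batch outputs collected by the batch_encoding_wrapper decorator below back
# into a single output object by concatenating the tensors field by field.
#
#     a = LightningIROutput(scores=torch.tensor([0.1, 0.2]))
#     b = LightningIROutput(scores=torch.tensor([0.3]))
#     _cat_outputs([a, b], LightningIROutput).scores  # tensor([0.1000, 0.2000, 0.3000])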


class BatchEncodingWrapper(Protocol):
    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    :param func: Function to wrap that takes a batch encoding
    :type func: BatchEncodingWrapper
    :raises e: If CUDA runs out of memory even after lowering the batch size to 1
    :raises ValueError: If no output was generated
    :return: Wrapped function
    :rtype: BatchEncodingWrapper
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
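

# Illustrative usage sketch (not part of the library source): a derived model can
# decorate any method that takes a BatchEncoding so that it transparently retries
# with smaller sub-batches on CUDA out-of-memory errors. The class and method names
# below are hypothetical.
#
#     class MyEncoderModel(LightningIRModel, BertModel):
#
#         @batch_encoding_wrapper
#         def encode(self, encoding: BatchEncoding) -> torch.Tensor:
#             return self._backbone_forward(**encoding).last_hidden_state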