Source code for lightning_ir.retrieve.plaid.packed_tensor

 1from pathlib import Path
 2from typing import Sequence, Tuple
 3
 4import torch
 5
 6
[docs] 7class PackedTensor:
[docs] 8 def __init__(self, packed_tensor: torch.Tensor, lengths: Sequence[int]) -> None: 9 self.packed_tensor = packed_tensor 10 self.lengths = list(lengths) 11 self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None
12 13 def __repr__(self) -> str: 14 return f"PackedTensor(packed_tensor={self.packed_tensor}, lengths={self.lengths})" 15 16 def __str__(self) -> str: 17 return self.__repr__() 18 19 @property 20 def segmented_tensor(self) -> Tuple[torch.Tensor, ...]: 21 if self._segmented_tensor is None: 22 self._segmented_tensor = torch.split(self.packed_tensor, self.lengths) 23 return self._segmented_tensor 24 25 def lookup( 26 self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False 27 ) -> "PackedTensor": 28 output_tensors = [] 29 lengths = [] 30 for lookup_idcs in torch.split(packed_idcs, idcs_lengths): 31 intermediate_tensors = [] 32 for idx in lookup_idcs: 33 intermediate_tensors.append(self.segmented_tensor[idx]) 34 35 cat_tensors = torch.cat(intermediate_tensors) 36 if unique: 37 cat_tensors = torch.unique(cat_tensors) 38 lengths.append(cat_tensors.shape[0]) 39 output_tensors.append(cat_tensors) 40 41 return PackedTensor(torch.cat(output_tensors), lengths) 42 43 def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor: 44 return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)