Source code for lightning_ir.retrieve.plaid.packed_tensor
1from pathlib import Path
2from typing import Sequence, Tuple
3
4import torch
5
6
[docs]
7class PackedTensor:
[docs]
8 def __init__(self, packed_tensor: torch.Tensor, lengths: Sequence[int]) -> None:
9 self.packed_tensor = packed_tensor
10 self.lengths = list(lengths)
11 self._segmented_tensor: Tuple[torch.Tensor, ...] | None = None
12
13 def __repr__(self) -> str:
14 return f"PackedTensor(packed_tensor={self.packed_tensor}, lengths={self.lengths})"
15
16 def __str__(self) -> str:
17 return self.__repr__()
18
19 @property
20 def segmented_tensor(self) -> Tuple[torch.Tensor, ...]:
21 if self._segmented_tensor is None:
22 self._segmented_tensor = torch.split(self.packed_tensor, self.lengths)
23 return self._segmented_tensor
24
25 def lookup(
26 self, packed_idcs: torch.Tensor, idcs_lengths: Sequence[int] | int, unique: bool = False
27 ) -> "PackedTensor":
28 output_tensors = []
29 lengths = []
30 for lookup_idcs in torch.split(packed_idcs, idcs_lengths):
31 intermediate_tensors = []
32 for idx in lookup_idcs:
33 intermediate_tensors.append(self.segmented_tensor[idx])
34
35 cat_tensors = torch.cat(intermediate_tensors)
36 if unique:
37 cat_tensors = torch.unique(cat_tensors)
38 lengths.append(cat_tensors.shape[0])
39 output_tensors.append(cat_tensors)
40
41 return PackedTensor(torch.cat(output_tensors), lengths)
42
43 def to_padded_tensor(self, pad_value: int = 0) -> torch.Tensor:
44 return torch.nn.utils.rnn.pad_sequence(self.segmented_tensor, batch_first=True, padding_value=pad_value)