Source code for lightning_ir.retrieve.base.indexer
1from __future__ import annotations
2
3import array
4import json
5from abc import ABC, abstractmethod
6from pathlib import Path
7from typing import TYPE_CHECKING, List, Type
8
9import torch
10
11if TYPE_CHECKING:
12 from ...bi_encoder import BiEncoderConfig, BiEncoderOutput
13 from ...data import IndexBatch
14
15
[docs]
class Indexer(ABC):
    """Abstract base class for objects that build an index inside a directory.

    Subclasses implement :meth:`add`. :meth:`save` persists the index
    configuration, the collected document ids and the document lengths to
    ``index_dir``.
    """

    def __init__(
        self,
        index_dir: Path,
        index_config: IndexConfig,
        bi_encoder_config: BiEncoderConfig,
        verbose: bool = False,
    ) -> None:
        """Store the configuration and initialise empty bookkeeping state.

        Args:
            index_dir: Directory that :meth:`save` writes its files to.
            index_config: Index configuration; its ``save`` method is invoked
                by :meth:`save`.
            bi_encoder_config: Bi-encoder configuration; only stored here.
            verbose: Verbosity flag; only stored here.
        """
        self.index_dir = index_dir
        self.index_config = index_config
        self.bi_encoder_config = bi_encoder_config
        # Bookkeeping state starts empty; presumably maintained by the
        # subclass's ``add`` implementation (not visible here).
        self.doc_ids: List[str] = []
        # Type code "I" stores unsigned C ints.
        self.doc_lengths = array.array("I")
        self.num_embeddings = 0
        self.num_docs = 0
        self.verbose = verbose

    @abstractmethod
    def add(self, index_batch: IndexBatch, output: BiEncoderOutput) -> None:
        """Add a batch to the index; must be implemented by subclasses."""
        ...

    def save(self) -> None:
        """Write the index configuration and bookkeeping files to ``index_dir``.

        Produces ``doc_ids.txt`` (the document ids joined by newlines) and
        ``doc_lengths.pt`` (an int32 tensor), after ``index_config.save`` has
        written its own output. The directory is not created here; this
        relies on ``index_config.save`` having done so.
        """
        self.index_config.save(self.index_dir)
        (self.index_dir / "doc_ids.txt").write_text("\n".join(self.doc_ids))
        if self.doc_lengths:
            # Tensor view over the array's buffer, reinterpreted as int32.
            doc_lengths = torch.frombuffer(self.doc_lengths, dtype=torch.int32)
        else:
            # torch.frombuffer raises ValueError for a zero-length buffer, so
            # an index without documents gets an explicit empty tensor.
            doc_lengths = torch.empty(0, dtype=torch.int32)
        torch.save(doc_lengths, self.index_dir / "doc_lengths.pt")
41
42
[docs]
class IndexConfig:
    """Index configuration that round-trips through ``config.json``.

    The instance attributes (``__dict__``) are the configuration values.
    :meth:`save` stores them together with two bookkeeping keys,
    ``index_dir`` and ``index_type``; :meth:`from_pretrained` strips those
    keys again before calling the constructor.
    """

    # Indexer class associated with this configuration type; presumably
    # overridden by subclasses -- confirm against the concrete configs.
    indexer_class: Type[Indexer] = Indexer

    @classmethod
    def from_pretrained(cls, index_dir: Path | str) -> "IndexConfig":
        """Load a configuration from ``<index_dir>/config.json``.

        Args:
            index_dir: Directory containing ``config.json``.

        Returns:
            An instance of ``cls`` built from the stored values.

        Raises:
            ValueError: If the stored ``index_type`` is not the name of ``cls``.
        """
        config_path = Path(index_dir) / "config.json"
        with open(config_path, "r") as config_file:
            stored = json.load(config_file)
        stored_type = stored["index_type"]
        if stored_type != cls.__name__:
            raise ValueError(f"Expected index_type {cls.__name__}, got {stored_type}")
        # Bookkeeping keys written by ``save`` are not constructor arguments.
        for bookkeeping_key in ("index_type", "index_dir"):
            stored.pop(bookkeeping_key, None)
        return cls(**stored)

    def save(self, index_dir: Path) -> None:
        """Write this configuration to ``<index_dir>/config.json``.

        The directory (including missing parents) is created if necessary.

        Args:
            index_dir: Target directory.
        """
        index_dir.mkdir(parents=True, exist_ok=True)
        payload = {
            **self.__dict__,
            "index_dir": str(index_dir),
            "index_type": self.__class__.__name__,
        }
        with open(index_dir / "config.json", "w") as config_file:
            json.dump(payload, config_file)

    def to_dict(self) -> dict:
        """Return a shallow copy of the instance attributes."""
        return dict(self.__dict__)