import codecs
import json
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Tuple, Type

import ir_datasets
from ir_datasets.datasets.base import Dataset
from ir_datasets.formats import BaseDocPairs, jsonl, trec, tsv
from ir_datasets.util import Cache, DownloadConfig, GzipExtract

CONSTITUENT_TYPE_MAP: Dict[str, Dict[str, Type]] = {
    "docs": {
        ".json": jsonl.JsonlDocs,
        ".jsonl": jsonl.JsonlDocs,
        ".tsv": tsv.TsvDocs,
    },
    "queries": {
        ".json": jsonl.JsonlQueries,
        ".jsonl": jsonl.JsonlQueries,
        ".tsv": tsv.TsvQueries,
    },
    "qrels": {".tsv": trec.TrecQrels, ".qrels": trec.TrecQrels},
    "scoreddocs": {".run": trec.TrecScoredDocs, ".tsv": trec.TrecScoredDocs},
    "docpairs": {".tsv": tsv.TsvDocPairs},
}
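
# For example, a local queries file ending in ".tsv" resolves to
# CONSTITUENT_TYPE_MAP["queries"][".tsv"], i.e. tsv.TsvQueries.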


def _load_constituent(
    constituent: Path | str | None,
    constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"],
    **kwargs,
) -> Any:
    """Resolve a constituent given either a registered `ir_datasets` id or a local
    file path, returning the matching handler (or None if no constituent is given)."""
    if constituent is None:
        return None
    # If the value is a registered dataset id, reuse that dataset's handler.
    if constituent in ir_datasets.registry._registered:
        return getattr(ir_datasets.load(constituent), f"{constituent_type}_handler")()
    # Otherwise treat the value as a path and pick the handler class by extension.
    constituent_path = Path(constituent)
    if not constituent_path.exists():
        raise ValueError(f"unable to load {constituent}, expected an `ir_datasets` id or valid path")
    # Use the first suffix so multi-suffix files (e.g., ".tsv.gz") map by their base extension.
    suffix = constituent_path.suffixes[0]
    constituent_types = CONSTITUENT_TYPE_MAP[constituent_type]
    if suffix not in constituent_types:
        raise ValueError(f"unknown file type: {suffix}, expected one of {constituent_types.keys()}")
    ConstituentType = constituent_types[suffix]
    return ConstituentType(Cache(None, constituent_path), **kwargs)
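
# A hypothetical illustration of both resolution paths (the id and path below
# are placeholders, not shipped files):
#
#   _load_constituent("msmarco-passage", "docs")     # reuses the registered docs handler
#   _load_constituent("data/queries.tsv", "queries")  # wraps the local file in tsv.TsvQueries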


def _register_local_dataset(
    dataset_id: str,
    docs: Path | str | None = None,
    queries: Path | str | None = None,
    qrels: Path | str | None = None,
    docpairs: Path | str | None = None,
    scoreddocs: Path | str | None = None,
    qrels_defs: Dict[int, str] | None = None,
):
    """Register a dataset assembled from local files (and/or existing dataset ids)
    under `dataset_id` so it can later be loaded via `ir_datasets.load`."""
    if dataset_id in ir_datasets.registry._registered:
        return

    docs = _load_constituent(docs, "docs")
    queries = _load_constituent(queries, "queries")
    qrels = _load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
    docpairs = _load_constituent(docpairs, "docpairs")
    scoreddocs = _load_constituent(scoreddocs, "scoreddocs")

    ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))
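
# A minimal sketch of registering a fully local dataset (file names are
# hypothetical):
#
#   _register_local_dataset(
#       "my-dataset",
#       docs="data/docs.jsonl",
#       queries="data/queries.tsv",
#       qrels="data/qrels.tsv",
#   )
#   dataset = ir_datasets.load("my-dataset")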


class ScoredDocTuple(NamedTuple):
    """A training sample: a query id, a tuple of doc ids, and (optionally) one
    teacher score per doc."""

    query_id: str
    doc_ids: Tuple[str, ...]
    scores: Tuple[float, ...] | None
    num_docs: int


class ScoredDocTuples(BaseDocPairs):
    """Docpairs handler that yields `ScoredDocTuple`s from either a JSON file of
    n-way samples or a TSV/run file of scored triples."""

    def __init__(self, docpairs_dlc):
        self._docpairs_dlc = docpairs_dlc

    def docpairs_path(self):
        return self._docpairs_dlc.path()

    def docpairs_iter(self):
        file_type = None
        if self._docpairs_dlc.path().suffix == ".json":
            file_type = "json"
        elif self._docpairs_dlc.path().suffix in (".tsv", ".run"):
            file_type = "tsv"
        else:
            raise ValueError(f"unknown file type: {self._docpairs_dlc.path().suffix}")
        with self._docpairs_dlc.stream() as f:
            f = codecs.getreader("utf8")(f)
            for line in f:
                if file_type == "json":
                    # n-way samples: [qid, [pid, score], [pid, score], ...]
                    data = json.loads(line)
                    qid, *doc_data = data
                    pids, scores = zip(*doc_data)
                    pids = tuple(str(pid) for pid in pids)
                else:
                    # scored triples: pos_score neg_score qid pid1 pid2
                    cols = line.rstrip().split()
                    pos_score, neg_score, qid, pid1, pid2 = cols
                    pids = (pid1, pid2)
                    scores = (float(pos_score), float(neg_score))
                yield ScoredDocTuple(str(qid), pids, scores, len(pids))

    def docpairs_cls(self):
        return ScoredDocTuple
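
# A hypothetical iteration example, assuming the triples registered below have
# been downloaded on first access:
#
#   dataset = ir_datasets.load("msmarco-passage/train/kd-docpairs")
#   for sample in dataset.docpairs_iter():
#       # ScoredDocTuple(query_id='...', doc_ids=('...', '...'), scores=(..., ...), num_docs=2)
#       break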


def _register_kd_docpairs():
    # Training triples with teacher scores from a BERT-CAT ensemble for
    # knowledge distillation, hosted on Zenodo.
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "kd-docpairs"
    cache_path = "bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv"
    dlc_contents = {
        "url": (
            "https://zenodo.org/record/4068216/files/bert_cat_ensemble_"
            "msmarcopassage_train_scores_ids.tsv?download=1"
        ),
        "expected_md5": "4d99696386f96a7f1631076bcc53ac3c",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.tsv"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def _register_colbert_docpairs():
    # ColBERTv2 64-way training samples with teacher scores, hosted on Hugging Face.
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "colbert-docpairs"
    cache_path = "colbert_64way.json"
    dlc_contents = {
        "url": (
            "https://huggingface.co/colbert-ir/colbertv2.0_msmarco_64way/"
            "resolve/main/examples.json?download=true"
        ),
        "expected_md5": "8be0c71e330ac54dcd77fba058d291c7",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.json"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def _register_rank_distillm():
    # Rank-DistiLLM runs over MS MARCO train queries: one re-ranked by
    # RankZephyr, one by a Set-Encoder, both hosted on Zenodo.
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "rank-distillm/rankzephyr"
    cache_path = "rank-distillm-rankzephyr.run"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__rankzephyr-colbert-10000-"
            "sampled-100__msmarco-passage-train-judged.run?download=1"
        ),
        "expected_md5": "49f8dbf2c1ee7a2ca1fe517eda528af6",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
    )

    file_id = "rank-distillm/set-encoder"
    cache_path = "rank-distillm-set-encoder.run.gz"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__set-encoder-colbert__"
            "msmarco-passage-train-judged.run.gz?download=1"
        ),
        "expected_md5": "1f069d0daa9842a54a858cc660149e1a",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
        # This run ships gzipped, so extract it before caching.
        extract=True,
    )


def register_msmarco(
    base_id: str,
    split_id: str,
    file_id: str,
    cache_path: str,
    dlc_contents: Dict[str, Any],
    file_name: str,
    ConstituentType: Type,
    extract: bool = False,
):
    """Register a derived MS MARCO dataset that combines the official docs,
    queries, and qrels of `{base_id}/{split_id}` with an additional downloadable
    constituent (e.g., training triples or a run file)."""
    dataset_id = f"{base_id}/{split_id}/{file_id}"
    if dataset_id in ir_datasets.registry._registered:
        return
    base_path = ir_datasets.util.home_path() / base_id
    dlc = DownloadConfig.context(base_id, base_path)
    dlc._contents[cache_path] = dlc_contents
    ir_dataset = ir_datasets.load(f"{base_id}/{split_id}")
    collection = ir_dataset.docs_handler()
    queries = ir_dataset.queries_handler()
    qrels = ir_dataset.qrels_handler()
    _dlc = dlc[cache_path]
    if extract:
        _dlc = GzipExtract(_dlc)
    constituent = ConstituentType(Cache(_dlc, base_path / split_id / file_name))
    dataset = Dataset(collection, queries, qrels, constituent)
    ir_datasets.registry.register(dataset_id, dataset)
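
# A hypothetical registration of an additional run file via register_msmarco
# (the id, url, and md5 below are placeholders):
#
#   register_msmarco(
#       "msmarco-passage",
#       "train",
#       "my-run",
#       "my-run.run",
#       {"url": "https://example.com/my-run.run", "expected_md5": "...", "cache_path": "my-run.run"},
#       "train/my-run.run",
#       trec.TrecScoredDocs,
#   )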


_register_kd_docpairs()
_register_colbert_docpairs()
_register_rank_distillm()
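
# After this module is imported, the registered ids resolve through the regular
# ir_datasets API, e.g. (the file is downloaded and cached on first access):
#
#   dataset = ir_datasets.load("msmarco-passage/train/rank-distillm/rankzephyr")
#   dataset.scoreddocs_iter()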