Source code for lightning_ir.data.ir_datasets_utils

import codecs
import json
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Tuple, Type

import ir_datasets
from ir_datasets.datasets.base import Dataset
from ir_datasets.formats import BaseDocPairs, jsonl, trec, tsv
from ir_datasets.util import Cache, DownloadConfig, GzipExtract

CONSTITUENT_TYPE_MAP: Dict[str, Dict[str, Type]] = {
    "docs": {
        ".json": jsonl.JsonlDocs,
        ".jsonl": jsonl.JsonlDocs,
        ".tsv": tsv.TsvDocs,
    },
    "queries": {
        ".json": jsonl.JsonlQueries,
        ".jsonl": jsonl.JsonlQueries,
        ".tsv": tsv.TsvQueries,
    },
    "qrels": {".tsv": trec.TrecQrels, ".qrels": trec.TrecQrels},
    "scoreddocs": {".run": trec.TrecScoredDocs, ".tsv": trec.TrecScoredDocs},
    "docpairs": {".tsv": tsv.TsvDocPairs},
}


def _load_constituent(
    constituent: Path | str | None,
    constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"],
    **kwargs,
) -> Any:
    if constituent is None:
        return None
    # Registered ir_datasets ids are resolved through the registry; the handler
    # object (not the bound method) is what `Dataset` expects as a constituent.
    if constituent in ir_datasets.registry._registered:
        return getattr(ir_datasets.load(constituent), f"{constituent_type}_handler")()
    # Otherwise treat the constituent as a local file and pick the loader from
    # CONSTITUENT_TYPE_MAP based on its first suffix.
    constituent_path = Path(constituent)
    if not constituent_path.exists():
        raise ValueError(f"unable to load {constituent}, expected an `ir_datasets` id or valid path")
    suffix = constituent_path.suffixes[0]
    constituent_types = CONSTITUENT_TYPE_MAP[constituent_type]
    if suffix not in constituent_types:
        raise ValueError(f"Unknown file type: {suffix}, expected one of {tuple(constituent_types)}")
    ConstituentType = constituent_types[suffix]
    return ConstituentType(Cache(None, constituent_path), **kwargs)
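
# --- Usage sketch (illustrative, not part of the module) ---
# `_load_constituent` resolves either a registered ir_datasets id or a local
# file, picking the loader from CONSTITUENT_TYPE_MAP by suffix. The local
# path below is an assumed example:
#
#     docs = _load_constituent("msmarco-passage", "docs")  # registered id
#     docs = _load_constituent("data/docs.tsv", "docs")    # local TSV file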


def _register_local_dataset(
    dataset_id: str,
    docs: Path | str | None = None,
    queries: Path | str | None = None,
    qrels: Path | str | None = None,
    docpairs: Path | str | None = None,
    scoreddocs: Path | str | None = None,
    qrels_defs: Dict[int, str] | None = None,
):
    if dataset_id in ir_datasets.registry._registered:
        return

    docs = _load_constituent(docs, "docs")
    queries = _load_constituent(queries, "queries")
    qrels = _load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
    docpairs = _load_constituent(docpairs, "docpairs")
    scoreddocs = _load_constituent(scoreddocs, "scoreddocs")

    ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))

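
# --- Usage sketch (illustrative, not part of the module) ---
# Register local files under a custom id and use them through the regular
# ir_datasets API; the id and paths below are assumptions:
#
#     _register_local_dataset(
#         "local/my-dataset",
#         docs="data/docs.tsv",
#         queries="data/queries.tsv",
#         qrels="data/qrels.tsv",
#     )
#     dataset = ir_datasets.load("local/my-dataset")
#     for doc in dataset.docs_iter():
#         ...
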
class ScoredDocTuple(NamedTuple):
    query_id: str
    doc_ids: Tuple[str, ...]
    scores: Tuple[float, ...] | None
    num_docs: int


class ScoredDocTuples(BaseDocPairs):
    def __init__(self, docpairs_dlc):
        self._docpairs_dlc = docpairs_dlc

    def docpairs_path(self):
        return self._docpairs_dlc.path()

    def docpairs_iter(self):
        file_type = None
        if self._docpairs_dlc.path().suffix == ".json":
            file_type = "json"
        elif self._docpairs_dlc.path().suffix in (".tsv", ".run"):
            file_type = "tsv"
        else:
            raise ValueError(f"Unknown file type: {self._docpairs_dlc.path().suffix}")
        with self._docpairs_dlc.stream() as f:
            f = codecs.getreader("utf8")(f)
            for line in f:
                if file_type == "json":
                    # JSON line: [qid, [pid, score], [pid, score], ...]
                    data = json.loads(line)
                    qid, *doc_data = data
                    pids, scores = zip(*doc_data)
                    pids = tuple(str(pid) for pid in pids)
                else:
                    # TSV/run line: pos_score neg_score qid pid1 pid2
                    cols = line.rstrip().split()
                    pos_score, neg_score, qid, pid1, pid2 = cols
                    pids = (pid1, pid2)
                    scores = (float(pos_score), float(neg_score))
                yield ScoredDocTuple(str(qid), pids, scores, len(pids))

    def docpairs_cls(self):
        return ScoredDocTuple
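
# --- Usage sketch (illustrative, not part of the module) ---
# ScoredDocTuples can wrap any local triple file via Cache; the file name is
# an assumed example. A ".tsv" file is expected to hold
# "pos_score neg_score qid pid1 pid2" per line:
#
#     docpairs = ScoredDocTuples(Cache(None, Path("train-triples.tsv")))
#     for pair in docpairs.docpairs_iter():
#         print(pair.query_id, pair.doc_ids, pair.scores)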


def _register_kd_docpairs():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "kd-docpairs"
    cache_path = "bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv"
    dlc_contents = {
        "url": (
            "https://zenodo.org/record/4068216/files/bert_cat_ensemble_"
            "msmarcopassage_train_scores_ids.tsv?download=1"
        ),
        "expected_md5": "4d99696386f96a7f1631076bcc53ac3c",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.tsv"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def _register_colbert_docpairs():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "colbert-docpairs"
    cache_path = "colbert_64way.json"
    dlc_contents = {
        "url": (
            "https://huggingface.co/colbert-ir/colbertv2.0_msmarco_64way/"
            "resolve/main/examples.json?download=true"
        ),
        "expected_md5": "8be0c71e330ac54dcd77fba058d291c7",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.json"
    register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def _register_rank_distillm():
    base_id = "msmarco-passage"
    split_id = "train"
    file_id = "rank-distillm/rankzephyr"
    cache_path = "rank-distillm-rankzephyr.run"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__rankzephyr-colbert-10000-"
            "sampled-100__msmarco-passage-train-judged.run?download=1"
        ),
        "expected_md5": "49f8dbf2c1ee7a2ca1fe517eda528af6",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
    )

    file_id = "rank-distillm/set-encoder"
    cache_path = "rank-distillm-set-encoder.run.gz"
    dlc_contents = {
        "url": (
            "https://zenodo.org/records/12528410/files/__set-encoder-colbert__"
            "msmarco-passage-train-judged.run.gz?download=1"
        ),
        "expected_md5": "1f069d0daa9842a54a858cc660149e1a",
        "cache_path": cache_path,
    }
    file_name = f"{split_id}/{file_id}.run"
    register_msmarco(
        base_id,
        split_id,
        file_id,
        cache_path,
        dlc_contents,
        file_name,
        trec.TrecScoredDocs,
        extract=True,
    )


def register_msmarco(
    base_id: str,
    split_id: str,
    file_id: str,
    cache_path: str,
    dlc_contents: Dict[str, Any],
    file_name: str,
    ConstituentType: Type,
    extract: bool = False,
):
    dataset_id = f"{base_id}/{split_id}/{file_id}"
    if dataset_id in ir_datasets.registry._registered:
        return
    base_path = ir_datasets.util.home_path() / base_id
    # Inject the download location into the MS MARCO download context so the
    # file is fetched and cached like any built-in ir_datasets resource.
    dlc = DownloadConfig.context(base_id, base_path)
    dlc._contents[cache_path] = dlc_contents
    # Reuse docs, queries, and qrels from the existing base dataset.
    ir_dataset = ir_datasets.load(f"{base_id}/{split_id}")
    collection = ir_dataset.docs_handler()
    queries = ir_dataset.queries_handler()
    qrels = ir_dataset.qrels_handler()
    _dlc = dlc[cache_path]
    if extract:
        _dlc = GzipExtract(_dlc)
    constituent = ConstituentType(Cache(_dlc, base_path / split_id / file_name))
    dataset = Dataset(collection, queries, qrels, constituent)
    ir_datasets.registry.register(dataset_id, dataset)


_register_kd_docpairs()
_register_colbert_docpairs()
_register_rank_distillm()
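
# --- Usage sketch (illustrative, not part of the module) ---
# The calls above register four ids, e.g. "msmarco-passage/train/kd-docpairs"
# and "msmarco-passage/train/rank-distillm/rankzephyr"; ir_datasets downloads
# the underlying file on first access:
#
#     dataset = ir_datasets.load("msmarco-passage/train/kd-docpairs")
#     for docpair in dataset.docpairs_iter():
#         ...  # ScoredDocTuple(query_id, doc_ids, scores, num_docs)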