Source code for lightning_ir.data.dataset

  1"""
  2Datasets for Lightning IR that data loading and sampling.
  3
  4This module defines several datasets that handle loading and sampling data for training and inference.
  5"""
  6
  7import csv
  8import warnings
  9from itertools import islice
 10from pathlib import Path
 11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
 12
 13import ir_datasets
 14import numpy as np
 15import pandas as pd
 16import torch
 17from ir_datasets.formats import GenericDoc, GenericDocPair
 18from torch.distributed import get_rank, get_world_size
 19from torch.utils.data import Dataset, IterableDataset, get_worker_info
 20
 21from .data import DocSample, QuerySample, RankSample
 22from .ir_datasets_utils import ScoredDocTuple
 23
 24RUN_HEADER = ["query_id", "q0", "doc_id", "rank", "score", "system"]
 25
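# RUN_HEADER mirrors the standard whitespace-separated TREC run file format, in which each
# line reads: query_id Q0 doc_id rank score system. A hypothetical run line for illustration
# (not part of this module):
#
#   q1 Q0 doc42 1 14.7 my-system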

class _IRDataset:
    def __init__(self, dataset: str) -> None:
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        :return: Dataset name
        :rtype: str
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        :return: Dataset id
        :rtype: str
        """
        if self.ir_dataset is None:
            return self.dataset
        return self.ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        :return: Document dataset id
        :rtype: str
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        :return: ir_datasets dataset
        :rtype: ir_datasets.Dataset | None
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        :return: Dataset map
        :rtype: Dict[str, str]
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

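    # DASHED_DATASET_MAP lets dashed dataset names (convenient for file names) stand in for the
    # slashed ir_datasets ids. Illustrative sketch, assuming the id
    # "msmarco-passage/trec-dl-2019/judged" is registered in the local ir_datasets installation:
    #
    #   dataset = _IRDataset("msmarco-passage-trec-dl-2019-judged")
    #   assert dataset.dataset == "msmarco-passage/trec-dl-2019/judged"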
    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        :raises ValueError: If no queries are found in the dataset
        :return: Queries
        :rtype: pd.Series
        """
        if self._queries is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            queries_iter = self.ir_dataset.queries_iter()
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in queries_iter},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        :raises ValueError: If no documents are found in the dataset
        :return: Documents
        :rtype: ir_datasets.indices.Docstore | Dict[str, GenericDoc]
        """
        if self._docs is None:
            if self.ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = self.ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        :return: Qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        if self.ir_dataset is None:
            return None
        qrels = pd.DataFrame(self.ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        qrels = qrels.droplevel(0, axis=1)
        self._qrels = qrels
        return self._qrels

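    # The qrels property pivots the flat (query_id, doc_id, iteration, relevance) tuples into a
    # DataFrame indexed by (query_id, doc_id) with one relevance column per iteration/subtopic.
    # Illustrative sketch of the equivalent pandas operations on made-up data (not part of the module):
    #
    #   raw = pd.DataFrame(
    #       [("q1", "d1", 0, 2), ("q1", "d2", 0, 0)],
    #       columns=["query_id", "doc_id", "iteration", "relevance"],
    #   )
    #   pivoted = raw.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1).droplevel(0, axis=1)
    #   assert pivoted.loc[("q1", "d1"), 0] == 2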

class _DataParallelIterableDataset(IterableDataset):
    # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
    def __init__(self) -> None:
        super().__init__()
        # TODO add support for multi-gpu and multi-worker inference; currently
        # doesn't work
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        try:
            world_size = get_world_size()
            process_rank = get_rank()
        except (RuntimeError, ValueError):
            world_size = 1
            process_rank = 0

        self.num_replicas = num_workers * world_size
        self.rank = process_rank * num_workers + worker_id

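# The rank/num_replicas pair is used by the iterable datasets below to shard data across
# dataloader workers and distributed processes: each replica reads every num_replicas-th
# sample starting at its own rank. Minimal sketch of the slicing scheme on hypothetical data:
#
#   samples = range(10)
#   rank, num_replicas = 1, 3
#   assert list(islice(samples, rank, None, num_replicas)) == [1, 4, 7]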
class QueryDataset(_IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        :param query_dataset: Path to file containing queries or valid ir_datasets id
        :type query_dataset: str
        :param num_queries: Number of queries in dataset. If None, the number of queries is inferred if possible,
            defaults to None
        :type num_queries: int | None, optional
        """
        super().__init__(query_dataset)
        super(_IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or self.ir_dataset.queries_count()

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        :yield: Query sample
        :rtype: Iterator[QuerySample]
        """
        start = self.rank
        stop = self.num_queries
        step = self.num_replicas
        for sample in islice(self.ir_dataset.queries_iter(), start, stop, step):
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            if self.qrels is not None:
                qrels = (
                    self.qrels.loc[[query_sample.query_id]]
                    .stack()
                    .rename("relevance")
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
                query_sample.qrels = qrels
            yield query_sample

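# A minimal usage sketch for QueryDataset, assuming the "vaswani" ir_datasets id is available
# locally (any dataset with queries works the same way):
#
#   dataset = QueryDataset("vaswani", num_queries=10)
#   for query_sample in dataset:
#       print(query_sample.query_id, query_sample.qrels)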
class DocDataset(_IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        :param doc_dataset: Path to file containing documents or valid ir_datasets id
        :type doc_dataset: str
        :param num_docs: Number of documents in dataset. If None, the number of documents is inferred if possible,
            defaults to None
        :type num_docs: int | None, optional
        :param text_fields: Fields to parse the document text from, defaults to None
        :type text_fields: Sequence[str] | None, optional
        """
        super().__init__(doc_dataset)
        super(_IRDataset, self).__init__()
        self.num_docs = num_docs
        self.text_fields = text_fields

    def __len__(self) -> int:
        """Number of documents in the dataset.

        :raises ValueError: If `num_docs` was not provided in the constructor and the number of documents cannot
            be inferred
        :return: Number of documents
        :rtype: int
        """
        # TODO fix len for multi-gpu and multi-worker inference
        num_docs = self.num_docs or self.ir_dataset.docs_count()
        if num_docs is None:
            raise ValueError("Unable to determine number of documents.")
        return num_docs

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        :yield: Doc sample
        :rtype: Iterator[DocSample]
        """
        start = self.rank
        stop = self.num_docs
        step = self.num_replicas
        for sample in islice(self.ir_dataset.docs_iter(), start, stop, step):
            yield DocSample.from_ir_dataset_sample(sample, self.text_fields)

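# A minimal usage sketch for DocDataset, again assuming the "vaswani" ir_datasets id is
# available locally; the doc_id attribute is assumed from DocSample in .data:
#
#   dataset = DocDataset("vaswani", num_docs=100)
#   for doc_sample in islice(dataset, 3):
#       print(doc_sample.doc_id)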
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        documents are non-relevant.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take the top ``sample_size`` documents from the ranking.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half of the ``sample_size`` documents from the top of the ranking and sample
        the other half randomly from the remainder.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents, assigning a higher sampling probability to documents at
        the top of the ranking.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        weights = 1 / np.log1p(documents["rank"])
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list using the given sampling strategy.

        :param df: Ranked list of documents
        :type df: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :param sampling_strategy: The sampling strategy to use
        :type sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"]
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        if sample_size == -1:
            return df
        if hasattr(Sampler, sampling_strategy):
            return getattr(Sampler, sampling_strategy)(df, sample_size)
        raise ValueError("Invalid sampling strategy.")

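# Illustrative sketch of Sampler.sample on a made-up ranked list (not part of the module):
#
#   ranking = pd.DataFrame(
#       {"doc_id": ["d1", "d2", "d3", "d4"], "rank": [1, 2, 3, 4], "score": [4.0, 3.0, 2.0, 1.0]}
#   )
#   top_two = Sampler.sample(ranking, sample_size=2, sampling_strategy="top")
#   # top_two contains the rows for d1 and d2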
class RunDataset(_IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        :param run_path_or_id: Path to a run file or valid ir_datasets id
        :type run_path_or_id: Path | str
        :param depth: Depth at which to cut off the ranking. If -1 the full ranking is kept, defaults to -1
        :type depth: int, optional
        :param sample_size: The number of documents to sample per query, defaults to -1
        :type sample_size: int, optional
        :param sampling_strategy: The sampling strategy to use to sample documents, defaults to "top"
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random'], optional
        :param targets: The data type to use as targets for a model during fine-tuning. If "relevance", the relevance
            judgements are parsed from the qrels, defaults to None
        :type targets: Literal['relevance', 'subtopic_relevance', 'rank', 'score'] | None, optional
        :param normalize_targets: Whether to normalize the targets between 0 and 1, defaults to False
        :type normalize_targets: bool, optional
        :param add_docs_not_in_ranking: Whether to add relevant documents to a sample that are in the qrels but not
            in the ranking, defaults to False
        :type add_docs_not_in_ranking: bool, optional
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        if self.sampling_strategy == "top" and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        self.run: pd.DataFrame | None = None

    def _setup(self):
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        run_path = self.run_path
        if run_path is None:
            if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                raise ValueError("Run file or dataset with scoreddocs required.")
            try:
                run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
            axis=1,
        )
        if "query" in run.columns:
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:

        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            load_func = suffix_load_map.get(run_path.suffixes[0], None)
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception:
                    pass

        # try loading run from ir_datasets
        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
            run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.")

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        :return: Qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            qrels = qrels.droplevel(0, axis=1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        :param idx: Index of the query
        :type idx: int
        :raises ValueError: If the targets are not found in the run file
        :return: Sampled query and documents
        :rtype: RankSample
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_max = targets.max()
                targets = (targets - targets_min) / (targets_max - targets_min)
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack()
                .rename("relevance")
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)

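# A minimal usage sketch for RunDataset with a hypothetical TREC run file; sample_size and
# sampling_strategy control how many and which documents are drawn per query. The path and
# file name below are illustrative only (the dataset id is parsed from the part after "__"):
#
#   dataset = RunDataset(
#       Path("runs/bm25__msmarco-passage-trec-dl-2019-judged.run"),
#       depth=100,
#       sample_size=8,
#       sampling_strategy="top",
#       targets="relevance",
#   )
#   rank_sample = dataset[0]  # RankSample with query, doc_ids, docs, targets, and qrels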
class TupleDataset(_IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n documents. Used for fine-tuning models on ranking tasks.

        :param tuples_dataset: Path to file containing tuples or valid ir_datasets id
        :type tuples_dataset: str
        :param targets: The data type to use as targets for a model during fine-tuning, defaults to "order"
        :type targets: Literal["order", "score"], optional
        :param num_docs: Maximum number of documents per query, defaults to None
        :type num_docs: int | None, optional
        """
        super().__init__(tuples_dataset)
        super(_IRDataset, self).__init__()
        self.targets = targets
        self.num_docs = num_docs

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        if isinstance(sample, GenericDocPair):
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            targets = (1.0, 0.0)
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
        elif isinstance(sample, ScoredDocTuple):
            doc_ids = sample.doc_ids[: self.num_docs]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                targets = sample.scores
            elif self.targets == "order":
                targets = tuple([1.0] + [0.0] * (sample.num_docs - 1))
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = targets[: self.num_docs]
        else:
            raise ValueError("Invalid sample type.")
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        :yield: A single tuple sample
        :rtype: Iterator[RankSample]
        """
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, targets = self._parse_sample(sample)
            if targets is not None:
                targets = torch.tensor(targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)
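

# A minimal usage sketch for TupleDataset, assuming the "msmarco-passage/train/triples-small"
# ir_datasets id (which provides doc pairs) is available locally; the RankSample attribute
# names are assumed to mirror the constructor arguments used above:
#
#   dataset = TupleDataset("msmarco-passage/train/triples-small", targets="order")
#   for rank_sample in islice(dataset, 2):
#       print(rank_sample.query_id, rank_sample.doc_ids, rank_sample.targets)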