"""
Datasets for Lightning IR that handle data loading and sampling.

This module defines several datasets that handle loading and sampling data for training and inference.
"""
6
7import csv
8import warnings
9from itertools import islice
10from pathlib import Path
11from typing import Any, Dict, Iterator, Literal, Sequence, Tuple
12
13import ir_datasets
14import numpy as np
15import pandas as pd
16import torch
17from ir_datasets.formats import GenericDoc, GenericDocPair
18from torch.distributed import get_rank, get_world_size
19from torch.utils.data import Dataset, IterableDataset, get_worker_info
20
21from .data import DocSample, QuerySample, RankSample
22from .ir_datasets_utils import ScoredDocTuple
23
# Column names assigned, in order, to the whitespace-separated columns of a headerless run file
# (used as ``names=`` when reading run files with pandas).
RUN_HEADER = [
    "query_id",
    "q0",
    "doc_id",
    "rank",
    "score",
    "system",
]
25
26
class _IRDataset:
    """Mixin giving lazy, cached access to the queries, documents, and qrels of an ir_datasets dataset."""

    def __init__(self, dataset: str) -> None:
        """Store the dataset name and initialise the lazily populated caches.

        :param dataset: Dataset name; slashes of an ir_datasets id may be written as dashes
        :type dataset: str
        """
        # cooperative call: concrete datasets combine this mixin with a torch dataset class
        super().__init__()
        self._dataset = dataset
        self._queries = None
        self._docs = None
        self._qrels = None

    @property
    def dataset(self) -> str:
        """Dataset name.

        :return: Dataset name
        :rtype: str
        """
        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

    @property
    def dataset_id(self) -> str:
        """Dataset id.

        :return: Dataset id
        :rtype: str
        """
        # resolve the dataset once; every access of ``self.ir_dataset`` triggers a fresh lookup
        ir_dataset = self.ir_dataset
        if ir_dataset is None:
            return self.dataset
        return ir_dataset.dataset_id()

    @property
    def docs_dataset_id(self) -> str:
        """ID of the dataset containing the documents.

        :return: Document dataset id
        :rtype: str
        """
        return ir_datasets.docs_parent_id(self.dataset_id)

    @property
    def ir_dataset(self) -> ir_datasets.Dataset | None:
        """Instance of ir_datasets.Dataset.

        :return: ir_datasets dataset, or None if the dataset name is not registered
        :rtype: ir_datasets.Dataset | None
        """
        try:
            return ir_datasets.load(self.dataset)
        except KeyError:
            return None

    @property
    def DASHED_DATASET_MAP(self) -> Dict[str, str]:
        """Map of dataset names with dashes to dataset names with slashes.

        :return: Dataset map
        :rtype: Dict[str, str]
        """
        return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

    @property
    def queries(self) -> pd.Series:
        """Queries in the dataset.

        :raises ValueError: If no queries are found in the dataset
        :return: Query texts indexed by query id
        :rtype: pd.Series
        """
        if self._queries is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._queries = pd.Series(
                {query.query_id: query.default_text() for query in ir_dataset.queries_iter()},
                name="text",
            )
            self._queries.index.name = "query_id"
        return self._queries

    @property
    def docs(self) -> ir_datasets.indices.Docstore | Dict[str, GenericDoc]:
        """Documents in the dataset.

        :raises ValueError: If no documents are found in the dataset
        :return: Documents
        :rtype: ir_datasets.indices.Docstore | Dict[str, GenericDoc]
        """
        if self._docs is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None:
                raise ValueError(f"Unable to find dataset {self.dataset} in ir-datasets")
            self._docs = ir_dataset.docs_store()
        return self._docs

    @property
    def qrels(self) -> pd.DataFrame | None:
        """Qrels in the dataset.

        The result is indexed by ``(query_id, doc_id)`` and has one column per iteration / subtopic.

        :return: Qrels, or None if the dataset is unknown or does not provide qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        ir_dataset = self.ir_dataset
        # datasets without relevance judgments have no ``qrels_iter``; report None instead of raising
        if ir_dataset is None or not ir_dataset.has_qrels():
            return None
        qrels = pd.DataFrame(ir_dataset.qrels_iter()).rename({"subtopic_id": "iteration"}, axis=1)
        if "iteration" not in qrels.columns:
            qrels["iteration"] = 0
        qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
        # pivot the iteration level into columns, then drop the outer column level
        qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
        qrels = qrels.droplevel(0, axis=1)
        self._qrels = qrels
        return self._qrels
137
138
139class _DataParallelIterableDataset(IterableDataset):
140 # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
141 def __init__(self) -> None:
142 super().__init__()
143 # TODO add support for multi-gpu and multi-worker inference; currently
144 # doesn't work
145 worker_info = get_worker_info()
146 num_workers = worker_info.num_workers if worker_info is not None else 1
147 worker_id = worker_info.id if worker_info is not None else 0
148
149 try:
150 world_size = get_world_size()
151 process_rank = get_rank()
152 except (RuntimeError, ValueError):
153 world_size = 1
154 process_rank = 0
155
156 self.num_replicas = num_workers * world_size
157 self.rank = process_rank * num_workers + worker_id
158
159
class QueryDataset(_IRDataset, _DataParallelIterableDataset):
    def __init__(self, query_dataset: str, num_queries: int | None = None) -> None:
        """Dataset containing queries.

        :param query_dataset: Path to file containing queries or valid ir_datasets id
        :type query_dataset: str
        :param num_queries: Number of queries in dataset. If None, an attempt is made to infer the number of
            queries, defaults to None
        :type num_queries: int | None, optional
        """
        super().__init__(query_dataset)
        super(_IRDataset, self).__init__()
        self.num_queries = num_queries

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        # TODO fix len for multi-gpu and multi-worker inference
        return self.num_queries or self.ir_dataset.queries_count()

    def __iter__(self) -> Iterator[QuerySample]:
        """Iterate over queries in the dataset.

        Every ``num_replicas``-th query starting at ``rank`` is yielded. Relevance judgments are attached to
        a query sample when the dataset provides qrels for that query.

        :yield: Query sample
        :rtype: Iterator[QuerySample]
        """
        # resolve the qrels once instead of re-evaluating the property for every query
        qrels = self.qrels
        judged_query_ids = set() if qrels is None else set(qrels.index.get_level_values("query_id"))
        shard = islice(self.ir_dataset.queries_iter(), self.rank, self.num_queries, self.num_replicas)
        for sample in shard:
            query_sample = QuerySample.from_ir_dataset_sample(sample)
            # queries without any judgment are not in the qrels index; ``.loc`` would raise a KeyError
            if query_sample.query_id in judged_query_ids:
                query_sample.qrels = (
                    qrels.loc[[query_sample.query_id]]
                    .stack()
                    .rename("relevance")
                    .astype(int)
                    .reset_index()
                    .to_dict(orient="records")
                )
            yield query_sample
205
206
class DocDataset(_IRDataset, _DataParallelIterableDataset):
    def __init__(self, doc_dataset: str, num_docs: int | None = None, text_fields: Sequence[str] | None = None) -> None:
        """Dataset containing documents.

        :param doc_dataset: Path to file containing documents or valid ir_datasets id
        :type doc_dataset: str
        :param num_docs: Number of documents in dataset. If None, an attempt is made to infer the number of
            documents, defaults to None
        :type num_docs: int | None, optional
        :param text_fields: Fields to parse the document text from, defaults to None
        :type text_fields: Sequence[str] | None, optional
        """
        super().__init__(doc_dataset)
        super(_IRDataset, self).__init__()
        self.text_fields = text_fields
        self.num_docs = num_docs

    def __len__(self) -> int:
        """Number of documents in the dataset.

        :raises ValueError: If ``num_docs`` was not provided in the constructor and the number of documents
            cannot be inferred
        :return: Number of documents
        :rtype: int
        """
        # TODO fix len for multi-gpu and multi-worker inference
        count = self.num_docs
        if not count:
            # no explicit (non-zero) count given; ask the underlying dataset
            count = self.ir_dataset.docs_count()
        if count is None:
            raise ValueError("Unable to determine number of documents.")
        return count

    def __iter__(self) -> Iterator[DocSample]:
        """Iterate over documents in the dataset.

        Every ``num_replicas``-th document starting at ``rank`` is yielded, up to ``num_docs``.

        :yield: Doc sample
        :rtype: Iterator[DocSample]
        """
        shard = islice(self.ir_dataset.docs_iter(), self.rank, self.num_docs, self.num_replicas)
        for raw_doc in shard:
            yield DocSample.from_ir_dataset_sample(raw_doc, self.text_fields)
249
250
class Sampler:
    """Helper class for sampling subsets of documents from a ranked list."""

    # names of the static methods that may be selected through ``Sampler.sample``
    _STRATEGIES = ("single_relevant", "top", "random", "log_random", "top_and_random")

    @staticmethod
    def single_relevant(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample a single relevant document. The remaining ``sample_size - 1``
        are non-relevant.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents, the relevant document first
        :rtype: pd.DataFrame
        """
        # a document counts as relevant if any of its relevance columns is positive
        relevance = documents.filter(like="relevance").max(axis=1).fillna(0)
        relevant = documents.loc[relevance.gt(0)].sample(1)
        # non-relevant documents are only drawn from rows that have a rank
        non_relevant_bool = relevance.eq(0) & ~documents["rank"].isna()
        num_non_relevant = non_relevant_bool.sum()
        sample_non_relevant = min(sample_size - 1, num_non_relevant)
        non_relevant = documents.loc[non_relevant_bool].sample(sample_non_relevant)
        return pd.concat([relevant, non_relevant])

    @staticmethod
    def top(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to select the first ``sample_size`` documents of the ranked list.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.head(sample_size)

    @staticmethod
    def top_and_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to take half of the ``sample_size`` documents from the top of the ranking and to
        sample the other half randomly from the remaining documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        top_size = sample_size // 2
        random_size = sample_size - top_size
        top = documents.head(top_size)
        random = documents.iloc[top_size:].sample(random_size)
        return pd.concat([top, random])

    @staticmethod
    def random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample ``sample_size`` documents.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        return documents.sample(sample_size)

    @staticmethod
    def log_random(documents: pd.DataFrame, sample_size: int) -> pd.DataFrame:
        """Sampling strategy to randomly sample documents with a higher probability to sample documents from the
        top of the ranking.

        :param documents: Ranked list of documents
        :type documents: pd.DataFrame
        :param sample_size: Number of documents to sample
        :type sample_size: int
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        weights = 1 / np.log1p(documents["rank"])
        # documents without a rank receive the smallest weight of all ranked documents
        weights[weights.isna()] = weights.min()
        return documents.sample(sample_size, weights=weights)

    @staticmethod
    def sample(
        df: pd.DataFrame,
        sample_size: int,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
    ) -> pd.DataFrame:
        """
        Samples a subset of documents from a ranked list given a sampling_strategy.

        :param df: Ranked list of documents
        :type df: pd.DataFrame
        :param sample_size: Number of documents to sample. If -1, the ranked list is returned unchanged
        :type sample_size: int
        :param sampling_strategy: Name of the sampling strategy to apply
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random']
        :raises ValueError: If the sampling strategy is not one of the supported strategies
        :return: Sampled documents
        :rtype: pd.DataFrame
        """
        if sample_size == -1:
            return df
        # check against the explicit strategy list; ``hasattr`` would also accept ``sample`` or dunder names
        if sampling_strategy not in Sampler._STRATEGIES:
            raise ValueError("Invalid sampling strategy.")
        return getattr(Sampler, sampling_strategy)(df, sample_size)
356
357
class RunDataset(_IRDataset, Dataset):
    def __init__(
        self,
        run_path_or_id: Path | str,
        depth: int = -1,
        sample_size: int = -1,
        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
        targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
        normalize_targets: bool = False,
        add_docs_not_in_ranking: bool = False,
    ) -> None:
        """Dataset containing a list of queries with a ranked list of documents per query. Subsets of the ranked list
        can be sampled using different sampling strategies.

        :param run_path_or_id: Path to a run file or valid ir_datasets id
        :type run_path_or_id: Path | str
        :param depth: Depth at which to cut off the ranking. If -1 the full ranking is kept, defaults to -1
        :type depth: int, optional
        :param sample_size: The number of documents to sample per query, defaults to -1
        :type sample_size: int, optional
        :param sampling_strategy: The sample strategy to use to sample documents, defaults to "top"
        :type sampling_strategy: Literal['single_relevant', 'top', 'random', 'log_random', 'top_and_random'], optional
        :param targets: The data type to use as targets for a model during fine-tuning. If relevance the relevance
            judgements are parsed from qrels, defaults to None
        :type targets: Literal['relevance', 'subtopic_relevance', 'rank', 'score'] | None, optional
        :param normalize_targets: Whether to normalize the targets between 0 and 1, defaults to False
        :type normalize_targets: bool, optional
        :param add_docs_not_in_ranking: Whether to add relevant documents to a sample that are in the qrels but not
            in the ranking, defaults to False
        :type add_docs_not_in_ranking: bool, optional
        """
        self.run_path = None
        if Path(run_path_or_id).is_file():
            self.run_path = Path(run_path_or_id)
            # the dataset name is the part of the file name after the last "__" and before the first "."
            dataset = self.run_path.name.split(".")[0].split("__")[-1]
        else:
            dataset = str(run_path_or_id)
        super().__init__(dataset)
        self.depth = depth
        self.sample_size = sample_size
        self.sampling_strategy = sampling_strategy
        self.targets = targets
        self.normalize_targets = normalize_targets
        self.add_docs_not_in_ranking = add_docs_not_in_ranking

        # a depth of -1 keeps the full ranking, so only an explicit cut-off can be exceeded
        if self.sampling_strategy == "top" and self.depth != -1 and self.sample_size > self.depth:
            warnings.warn(
                "Sample size is greater than depth and top sampling strategy is used. "
                "This can cause documents to be sampled that are not contained "
                "in the run file, but that are present in the qrels."
            )

        # populated lazily by ``_setup`` on first access
        self.run: pd.DataFrame | None = None

    def _setup(self):
        """Load the run once, join it with the qrels, and group it by query."""
        if self.run is not None:
            return
        self.run = self._load_run()
        self.run = self.run.drop_duplicates(["query_id", "doc_id"])

        if self.qrels is not None:
            run_query_ids = pd.Index(self.run["query_id"].drop_duplicates())
            qrels_query_ids = self.qrels.index.get_level_values("query_id").unique()
            query_ids = run_query_ids.intersection(qrels_query_ids)
            # drop queries from the run that have no relevance judgments
            if len(run_query_ids.difference(qrels_query_ids)):
                self.run = self.run[self.run["query_id"].isin(query_ids)]
            # outer join if docs are from ir_datasets else only keep docs in run
            how = "left"
            if self._docs is None and self.add_docs_not_in_ranking:
                how = "outer"
            self.run = self.run.merge(
                self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
                on=["query_id", "doc_id"],
                how=how,
            )

        if self.sample_size != -1:
            # remove queries that do not have enough documents to draw a full sample from
            num_docs_per_query = self.run.groupby("query_id").transform("size")
            self.run = self.run[num_docs_per_query >= self.sample_size]

        self.run = self.run.sort_values(["query_id", "rank"])
        self.run_groups = self.run.groupby("query_id")
        self.query_ids = list(self.run_groups.groups.keys())

        if self.depth != -1 and self.run["rank"].max() < self.depth:
            warnings.warn("Depth is greater than the maximum rank in the run file.")

    @staticmethod
    def _load_csv(path: Path) -> pd.DataFrame:
        """Read a headerless, whitespace-separated run file, keeping query id, doc id, rank, and score."""
        return pd.read_csv(
            path,
            sep=r"\s+",
            header=None,
            names=RUN_HEADER,
            usecols=[0, 2, 3, 4],
            dtype={"query_id": str, "doc_id": str},
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    @staticmethod
    def _load_parquet(path: Path) -> pd.DataFrame:
        """Read a parquet run file and map ``qid`` / ``docid`` / ``docno`` to the canonical column names."""
        return pd.read_parquet(path).rename(
            {
                "qid": "query_id",
                "docid": "doc_id",
                "docno": "doc_id",
            },
            axis=1,
        )

    @staticmethod
    def _load_json(path: Path) -> pd.DataFrame:
        """Read a JSON or JSON-lines run file, parsing all id columns as strings."""
        kwargs: Dict[str, Any] = {}
        if ".jsonl" in path.suffixes:
            kwargs["lines"] = True
            kwargs["orient"] = "records"
        run = pd.read_json(
            path,
            **kwargs,
            dtype={
                "query_id": str,
                "qid": str,
                "doc_id": str,
                "docid": str,
                "docno": str,
            },
        )
        return run

    def _get_run_path(self) -> Path | None:
        """Return the run file path, falling back to the scoreddocs file of the ir_datasets dataset.

        :raises ValueError: If neither a run file nor a dataset with scoreddocs is available
        :return: Path of the run file, or None if the dataset does not expose a scoreddocs path
        :rtype: Path | None
        """
        run_path = self.run_path
        if run_path is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is None or not ir_dataset.has_scoreddocs():
                raise ValueError("Run file or dataset with scoreddocs required.")
            try:
                run_path = ir_dataset.scoreddocs_handler().scoreddocs_path()
            except NotImplementedError:
                # no file path available; ``_load_run`` falls back to iterating the scoreddocs
                pass
        return run_path

    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
        """Normalise column names, extract inline queries / documents, cut off at ``depth``, and cast dtypes."""
        run = run.rename(
            {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
            axis=1,
        )
        if "query" in run.columns:
            # the run carries the query text itself; use it instead of looking queries up in ir_datasets
            self._queries = run.drop_duplicates("query_id").set_index("query_id")["query"].rename("text")
            run = run.drop("query", axis=1)
        if "text" in run.columns:
            # the run carries the document text itself; wrap it so it can be accessed like ir_datasets docs
            self._docs = run.set_index("doc_id")["text"].map(lambda x: GenericDoc("", x)).to_dict()
            run = run.drop("text", axis=1)
        if self.depth != -1:
            run = run[run["rank"] <= self.depth]
        dtypes = {"rank": np.int32}
        if "score" in run.columns:
            dtypes["score"] = np.float32
        run = run.astype(dtypes)
        return run

    def _load_run(self) -> pd.DataFrame:
        """Load the run from the run file or, failing that, from the scoreddocs of the ir_datasets dataset.

        :raises ValueError: If the run can be loaded from neither source; an error raised by the file loader
            is attached as the cause
        :return: Cleaned run
        :rtype: pd.DataFrame
        """
        suffix_load_map = {
            ".tsv": self._load_csv,
            ".run": self._load_csv,
            ".csv": self._load_csv,
            ".parquet": self._load_parquet,
            ".json": self._load_json,
            ".jsonl": self._load_json,
        }
        run = None
        load_error: Exception | None = None

        # try loading run from file
        run_path = self._get_run_path()
        if run_path is not None:
            suffixes = run_path.suffixes
            # a file name without any suffix has no known format
            load_func = suffix_load_map.get(suffixes[0]) if suffixes else None
            if load_func is not None:
                try:
                    run = load_func(run_path)
                except Exception as error:
                    # best effort: remember the failure and fall back to ir_datasets below
                    load_error = error

        # try loading run from ir_datasets
        if run is None:
            ir_dataset = self.ir_dataset
            if ir_dataset is not None and ir_dataset.has_scoreddocs():
                run = pd.DataFrame(ir_dataset.scoreddocs_iter())
                run["rank"] = run.groupby("query_id")["score"].rank(method="first", ascending=False)
                run = run.sort_values(["query_id", "rank"])

        if run is None:
            raise ValueError("Invalid run file format.") from load_error

        run = self._clean_run(run)
        return run

    @property
    def qrels(self) -> pd.DataFrame | None:
        """The qrels in the dataset. If the dataset does not contain qrels, the qrels are None.

        Relevance judgments stored in a ``relevance`` column of the run take precedence over the qrels of the
        ir_datasets dataset.

        :return: Qrels
        :rtype: pd.DataFrame | None
        """
        if self._qrels is not None:
            return self._qrels
        if self.run is not None and "relevance" in self.run:
            qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
            if "iteration" in self.run:
                qrels["iteration"] = self.run["iteration"]
            else:
                qrels["iteration"] = "0"
            # the judgments now live in the qrels; remove them from the run
            self.run = self.run.drop(["relevance", "iteration"], axis=1, errors="ignore")
            qrels = qrels.drop_duplicates(["query_id", "doc_id", "iteration"])
            qrels = qrels.set_index(["query_id", "doc_id", "iteration"]).unstack(level=-1)
            qrels = qrels.droplevel(0, axis=1)
            self._qrels = qrels
            return self._qrels
        return super().qrels

    def __len__(self) -> int:
        """Number of queries in the dataset.

        :return: Number of queries
        :rtype: int
        """
        self._setup()
        return len(self.query_ids)

    def __getitem__(self, idx: int) -> RankSample:
        """Samples a single query and corresponding ranked documents from the run. The documents are sampled according
        to the sampling strategy and sample size.

        :param idx: Index of the query
        :type idx: int
        :raises ValueError: If the targets are not found in the run file
        :return: Sampled query and documents
        :rtype: RankSample
        """
        self._setup()
        query_id = str(self.query_ids[idx])
        group = self.run_groups.get_group(query_id).copy()
        query = self.queries[query_id]
        group = Sampler.sample(group, self.sample_size, self.sampling_strategy)

        doc_ids = tuple(group["doc_id"])
        docs = tuple(self.docs.get(doc_id).default_text() for doc_id in doc_ids)

        targets = None
        if self.targets is not None:
            # select every column whose name contains the target name, in the order of the sampled documents
            filtered = group.set_index("doc_id").loc[list(doc_ids)].filter(like=self.targets).fillna(0)
            if filtered.empty:
                raise ValueError(f"targets `{self.targets}` not found in run file")
            targets = torch.from_numpy(filtered.values)
            if self.targets == "rank":
                # invert ranks to be higher is better (necessary for loss functions)
                targets = self.depth - targets + 1
            if self.normalize_targets:
                targets_min = targets.min()
                targets_range = targets.max() - targets_min
                # all targets equal: divide by one instead of zero so the result is 0 rather than NaN
                if targets_range == 0:
                    targets_range = torch.ones_like(targets_range)
                targets = (targets - targets_min) / targets_range
        qrels = None
        if self.qrels is not None:
            qrels = (
                self.qrels.loc[[query_id]]
                .stack()
                .rename("relevance")
                .astype(int)
                .reset_index()
                .to_dict(orient="records")
            )
        return RankSample(query_id, query, doc_ids, docs, targets, qrels)
627
628
class TupleDataset(_IRDataset, IterableDataset):
    def __init__(
        self,
        tuples_dataset: str,
        targets: Literal["order", "score"] = "order",
        num_docs: int | None = None,
    ) -> None:
        """Dataset containing tuples of a query and n-documents. Used for fine-tuning models on ranking tasks.

        :param tuples_dataset: Path to file containing tuples or valid ir_datasets id
        :type tuples_dataset: str
        :param targets: The data type to use as targets for a model during fine-tuning, defaults to "order"
        :type targets: Literal["order", "score"], optional
        :param num_docs: Maximum number of documents per query, defaults to None
        :type num_docs: int | None, optional
        """
        super().__init__(tuples_dataset)
        super(_IRDataset, self).__init__()
        self.num_docs = num_docs
        self.targets = targets

    def _parse_sample(
        self, sample: ScoredDocTuple | GenericDocPair
    ) -> Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]:
        """Extract the document ids, document texts, and targets from a single tuple sample.

        :param sample: Document pair or scored document tuple
        :type sample: ScoredDocTuple | GenericDocPair
        :raises ValueError: If the sample type or the configured targets are not supported for the sample
        :return: Document ids, document texts, and targets
        :rtype: Tuple[Tuple[str, ...], Tuple[str, ...], Tuple[float, ...] | None]
        """
        if isinstance(sample, GenericDocPair):
            # a plain pair has no scores; the first document is the preferred one
            if self.targets == "score":
                raise ValueError("ScoredDocTuple required for score targets.")
            doc_ids = (sample.doc_id_a, sample.doc_id_b)
            targets = (1.0, 0.0)
        elif isinstance(sample, ScoredDocTuple):
            limit = self.num_docs
            doc_ids = sample.doc_ids[:limit]
            if self.targets == "score":
                if sample.scores is None:
                    raise ValueError("tuples dataset does not contain scores")
                all_targets = sample.scores
            elif self.targets == "order":
                # only the first document of the tuple receives a positive target
                all_targets = (1.0,) + (0.0,) * (sample.num_docs - 1)
            else:
                raise ValueError(f"invalid value for targets, got {self.targets}, expected one of (order, score)")
            targets = all_targets[:limit]
        else:
            raise ValueError("Invalid sample type.")
        doc_store = self.docs
        docs = tuple(doc_store.get(doc_id).default_text() for doc_id in doc_ids)
        return doc_ids, docs, targets

    def __iter__(self) -> Iterator[RankSample]:
        """Iterates over tuples in the dataset.

        :yield: A single tuple sample
        :rtype: Iterator[RankSample]
        """
        for sample in self.ir_dataset.docpairs_iter():
            query_id = sample.query_id
            query = self.queries.loc[query_id]
            doc_ids, docs, raw_targets = self._parse_sample(sample)
            targets = raw_targets if raw_targets is None else torch.tensor(raw_targets)
            yield RankSample(query_id, query, doc_ids, docs, targets)