from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from transformers.modeling_utils import load_state_dict

from ...base import LightningIRModelClassFactory
from ...bi_encoder.model import BiEncoderEmbedding, ScoringFunction
from ..col import ColModel
from .config import XTRConfig


class XTRScoringFunction(ScoringFunction):
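    """Scoring function for the XTR model. Simulated token retrieval during training is
    not implemented yet (see the TODO in ``compute_similarity``), so scoring currently
    behaves like the parent :class:`ScoringFunction`."""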
    def __init__(self, config: XTRConfig) -> None:
        super().__init__(config)
        self.config: XTRConfig

    def compute_similarity(
        self, query_embeddings: BiEncoderEmbedding, doc_embeddings: BiEncoderEmbedding
    ) -> torch.Tensor:
        similarity = super().compute_similarity(query_embeddings, doc_embeddings)

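        # During training XTR is intended to simulate token retrieval: similarities that do
        # not make the in-batch top-k cutoff would be masked out (see the sketch below).
        # This is not implemented yet, so the plain similarity is returned unchanged.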
        if self.training and self.xtr_token_retrieval_k is not None:
            pass
            # TODO implement simulated token retrieval

            # if not torch.all(num_docs == num_docs[0]):
            #     raise ValueError("XTR token retrieval does not support variable number of documents.")
            # query_embeddings = query_embeddings[:: num_docs[0]]
            # doc_embeddings = doc_embeddings.view(1, 1, -1, doc_embeddings.shape[-1])
            # ib_similarity = super().compute_similarity(
            #     query_embeddings,
            #     doc_embeddings,
            #     query_scoring_mask[:: num_docs[0]],
            #     doc_scoring_mask.view(1, -1),
            #     num_docs,
            # )
            # top_k_similarity = ib_similarity.topk(self.xtr_token_retrieval_k, dim=-1)
            # cut_off_similarity = top_k_similarity.values[..., [-1]].repeat_interleave(num_docs, dim=0)
            # if self.fill_strategy == "min":
            #     fill = cut_off_similarity.expand_as(similarity)[similarity < cut_off_similarity]
            # elif self.fill_strategy == "zero":
            #     fill = 0
            # similarity[similarity < cut_off_similarity] = fill
        return similarity

    # def aggregate(
    #     self,
    #     scores: torch.Tensor,
    #     mask: torch.Tensor,
    #     query_aggregation_function: Literal["max", "sum", "mean", "harmonic_mean"],
    # ) -> torch.Tensor:
    #     if self.training and self.normalization == "Z":
    #         # Z-normalization
    #         mask = mask & (scores != 0)
    #     return super().aggregate(scores, mask, query_aggregation_function)


class XTRModel(ColModel):
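    """XTR bi-encoder model built on :class:`ColModel`. Original XTR checkpoints that ship a
    separate ``2_Dense`` projection module are converted via :meth:`from_xtr_checkpoint`."""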
    config_class = XTRConfig

    def __init__(self, config: XTRConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.scoring_function = XTRScoringFunction(config)
        self.config: XTRConfig

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "XTRModel":
        try:
            # Probe for the dense projection weights shipped with original XTR checkpoints.
            hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        except Exception:
            # Not an original XTR checkpoint; load it as a regular Lightning IR model.
            return super().from_pretrained(model_name_or_path, *args, **kwargs)
        return cls.from_xtr_checkpoint(model_name_or_path)

    @classmethod
    def from_xtr_checkpoint(cls, model_name_or_path: Path | str) -> "XTRModel":
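        """Converts a checkpoint in the original XTR layout (T5 encoder weights in
        ``model.safetensors`` plus a ``2_Dense`` projection module) into an :class:`XTRModel`."""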
        from transformers import T5EncoderModel

        cls = LightningIRModelClassFactory(XTRConfig).from_backbone_class(T5EncoderModel)
        config = cls.config_class.from_pretrained(model_name_or_path)
        config.update(
            {
                "name_or_path": str(model_name_or_path),
                "similarity_function": "dot",
                "query_aggregation_function": "sum",
                "query_expansion": False,
                "doc_expansion": False,
                "doc_pooling_strategy": None,
                "doc_mask_scoring_tokens": None,
                "normalize": True,
                "sparsification": None,
                "add_marker_tokens": False,
                "embedding_dim": 128,
                "projection": "linear_no_bias",
            }
        )
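        # Download the encoder weights and the dense projection, remap the projection weights
        # ("linear.weight" -> "projection.weight"), and copy the shared token embeddings to the
        # encoder's embedding key.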
        state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="model.safetensors")
        state_dict = load_state_dict(state_dict_path)
        linear_state_dict_path = hf_hub_download(repo_id=str(model_name_or_path), filename="2_Dense/pytorch_model.bin")
        linear_state_dict = load_state_dict(linear_state_dict_path)
        linear_state_dict["projection.weight"] = linear_state_dict.pop("linear.weight")
        state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
        state_dict.update(linear_state_dict)
        model = cls(config=config)
        model.load_state_dict(state_dict)
        return model