from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Tuple

import torch

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..bi_encoder import BiEncoderOutput
    from ..data import TrainBatch


class LossFunction(ABC):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        targets = batch.targets
        if targets is None:
            raise ValueError("Expected targets in TrainBatch")
        if targets.ndim > scores.ndim:
            return targets.max(-1).values
        return targets


class ScoringLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor: ...


class EmbeddingLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...


class PairwiseLossFunction(ScoringLossFunction):
    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # a (positive, negative) pair is formed for every pair of documents of the same
        # query where the positive document's label is strictly greater than the negative's
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

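# Illustrative sketch (not part of the library API): how get_pairwise_idcs selects
# training pairs from graded targets. For two queries with three documents each,
# every (query, positive, negative) index triple whose positive document has a
# strictly higher label is returned:
#
#   targets = torch.tensor([[2, 1, 0], [1, 0, 0]])
#   query_idcs, pos_idcs, neg_idcs = torch.nonzero(
#       targets[..., None] > targets[:, None], as_tuple=True
#   )
#   # query_idcs -> tensor([0, 0, 0, 1, 1])
#   # pos_idcs   -> tensor([0, 0, 1, 0, 0])
#   # neg_idcs   -> tensor([1, 2, 2, 1, 2])
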
class ListwiseLossFunction(ScoringLossFunction):
    pass


class MarginMSE(PairwiseLossFunction):
    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if isinstance(self.margin, float):
            target_margin = torch.tensor(self.margin, device=scores.device)
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss


class ConstantMarginMSE(MarginMSE):
    def __init__(self, margin: float = 1.0):
        super().__init__(margin)


class SupervisedMarginMSE(MarginMSE):
    def __init__(self):
        super().__init__("scores")

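# Usage sketch (assumes the loss is passed to a Lightning IR module configured
# elsewhere): ConstantMarginMSE pushes the score difference of every
# (positive, negative) pair towards a fixed margin, while SupervisedMarginMSE
# uses the difference of the teacher targets as the per-pair margin
# (margin-MSE distillation):
#
#   loss_fn = ConstantMarginMSE(margin=1.0)  # fixed target margin of 1
#   loss_fn = SupervisedMarginMSE()          # margin taken from batch.targets
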
class RankNet(PairwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss

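# Derivation note (added for clarity, not part of the library): with an all-ones
# target, binary_cross_entropy_with_logits(margin, 1) equals -log(sigmoid(pos - neg)),
# i.e. the classic RankNet objective that rewards scoring the positive document
# above the negative one.
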
class KLDivergence(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss

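# Note (added for clarity, not part of the library): kl_div(input, target,
# log_target=True) computes KL(target || input) for log-probability inputs, so
# this loss is KL(softmax(targets) || softmax(scores)) averaged over queries,
# pulling the model's score distribution towards the normalised target
# distribution.
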
class InfoNCE(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

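# Illustrative sketch (not part of the library): argmax turns graded targets into
# the index of the highest-labelled (positive) document per query, and
# cross_entropy then maximises its softmax probability among that query's
# candidates:
#
#   targets = torch.tensor([[1, 0, 0], [0, 2, 1]])
#   targets.argmax(dim=1)  # -> tensor([0, 1])
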
class ApproxLossFunction(ListwiseLossFunction):
    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks

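# Worked example (sketch, not part of the library): get_approx_ranks replaces hard
# ranks with a differentiable approximation. For a single query with scores
# [3, 1, 2], a small temperature turns the sigmoid of the pairwise score
# differences into a near-step function and the approximate ranks approach the
# true ranks; larger temperatures give smoothed ranks:
#
#   scores = torch.tensor([[3.0, 1.0, 2.0]])
#   ApproxLossFunction.get_approx_ranks(scores, temperature=1e-3)
#   # -> tensor([[1., 3., 2.]])
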
class ApproxNDCG(ApproxLossFunction):
    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()

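# Worked example (sketch, not part of the library): get_dcg follows the
# exponential-gain DCG definition, DCG = sum_i (2^rel_i - 1) / log2(1 + rank_i),
# and get_ndcg normalises by the DCG of the ideal ordering. For true ranks
# [1, 2, 3] and graded targets [1, 0, 1]:
#
#   ranks = torch.tensor([[1.0, 2.0, 3.0]])
#   targets = torch.tensor([[1.0, 0.0, 1.0]])
#   ApproxNDCG.get_dcg(ranks, targets)   # -> tensor([1.5000])
#   ApproxNDCG.get_ndcg(ranks, targets)  # -> roughly tensor([0.92])
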
class ApproxMRR(ApproxLossFunction):
    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()

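# Worked example (sketch, not part of the library): get_mrr clamps graded targets
# to binary relevance and takes the best reciprocal rank of a relevant document.
# For ranks [1, 2, 3] and targets [0, 1, 1] the first relevant document sits at
# rank 2:
#
#   ApproxMRR.get_mrr(torch.tensor([[1.0, 2.0, 3.0]]), torch.tensor([[0.0, 1.0, 1.0]]))
#   # -> tensor([0.5000])
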
class ApproxRankMSE(ApproxLossFunction):
    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss


class NeuralLossFunction(ListwiseLossFunction):
    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        one = torch.ones((dim, 1), device=scores.device)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets

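# Illustrative sketch (not part of the library): neural_sort produces a relaxed
# permutation matrix that, for a small temperature and no tied scores, converges
# to the permutation sorting the scores in descending order; get_sorted_targets
# applies it to the targets. NeuralLossFunction itself is abstract (compute_loss
# is still a TODO), so the sketch uses a throwaway concrete subclass:
#
#   class _Sketch(NeuralLossFunction):
#       def compute_loss(self, output, batch):
#           raise NotImplementedError
#
#   loss_fn = _Sketch(temperature=1e-3)
#   scores = torch.tensor([[3.0, 1.0, 2.0]])
#   loss_fn.neural_sort(scores).round()
#   # -> tensor([[[1., 0., 0.],
#   #             [0., 0., 1.],
#   #             [0., 1., 0.]]])
#   loss_fn.get_sorted_targets(scores, torch.tensor([[1.0, 0.0, 2.0]]))
#   # -> roughly tensor([[1., 2., 0.]]) (targets reordered by decreasing score)
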
class InBatchLossFunction(LossFunction):
    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples
        if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
            raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.neg_sampling_technique == "all_and_non_first":
            neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
        elif self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        return neg_mask

    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        num_queries, num_docs = output.scores.shape
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
            neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs

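# Worked example (sketch, not part of the library): get_ib_idcs flattens the
# (num_queries x num_docs) document grid and returns indices into that flattened
# list. For 2 queries with 3 documents each, pos_sampling_technique="first" and
# neg_sampling_technique="all", only output.scores (assumed shape (2, 3)) is
# inspected and the result would be
#
#   pos_idcs -> tensor([0, 3])              # the first document of each query
#   neg_idcs -> tensor([3, 4, 5, 0, 1, 2])  # every document of the other query
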
class ScoreBasedInBatchLossFunction(InBatchLossFunction):

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        idcs = targets.argsort(descending=True).argsort().cpu()
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        score_diff = (max_score - targets).cpu()
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            assert neg_mask[query_idx, additional_neg_idcs].all().item()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples).item()
        return neg_mask


class InBatchCrossEntropy(InBatchLossFunction):
    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss


class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss


class RegularizationLossFunction(EmbeddingLossFunction):
    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        return query_embeddings.embeddings, doc_embeddings.embeddings


class L2Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        loss = query_loss + doc_loss
        return loss


class L1Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        loss = query_loss + doc_loss
        return loss


class FLOPSRegularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        anti_zero = 1 / (torch.sum(query_embeddings) ** 2) + 1 / (torch.sum(doc_embeddings) ** 2)
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero
        return loss
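
# Worked example (sketch, not part of the library): the FLOPS regulariser
# penalises the squared per-dimension mean of the absolute activations, so
# concentrating the same total mass in few dimensions costs more than spreading
# it out, which pushes sparse embeddings towards balanced dimension usage. The
# anti_zero term guards against the degenerate all-zero solution.
#
#   emb = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
#   torch.sum(torch.mean(torch.abs(emb), dim=0) ** 2)  # -> tensor(1.)
#   emb = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
#   torch.sum(torch.mean(torch.abs(emb), dim=0) ** 2)  # -> tensor(0.5000)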