Source code for lightning_ir.loss.loss

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Tuple

import torch

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..bi_encoder import BiEncoderOutput
    from ..data import TrainBatch

class LossFunction(ABC):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        targets = batch.targets
        if targets is None:
            raise ValueError("Expected targets in TrainBatch")
        if targets.ndim > scores.ndim:
            return targets.max(-1).values
        return targets

class ScoringLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor: ...

class EmbeddingLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...

class PairwiseLossFunction(ScoringLossFunction):
    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # positive items are items whose label is greater than another label in the same sample
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

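To make the pairwise indexing concrete, the following sketch reproduces the expression from get_pairwise_idcs on a made-up target tensor (not part of the library source): each returned triple identifies a query, a document with the higher label, and a document with the lower label.

# illustrative only: two queries, three graded relevance labels each
targets = torch.tensor([[2, 1, 0], [1, 0, 0]])
query_idcs, pos_idcs, neg_idcs = torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)
# query_idcs: tensor([0, 0, 0, 1, 1]); pos_idcs: tensor([0, 0, 1, 0, 0]); neg_idcs: tensor([1, 2, 2, 1, 2])
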
class ListwiseLossFunction(ScoringLossFunction):
    pass

class MarginMSE(PairwiseLossFunction):
    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if isinstance(self.margin, float):
            target_margin = torch.tensor(self.margin, device=scores.device)
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss

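A minimal numeric sketch of the objective MarginMSE computes, using made-up score and target tensors; constructing LightningIROutput/TrainBatch objects is omitted here because their constructors are defined outside this module.

# illustrative only: one query with a positive and a negative document
scores = torch.tensor([[12.0, 7.5]])
targets = torch.tensor([[1, 0]])
margin = scores[:, 0] - scores[:, 1]                               # predicted margin: 4.5
loss = torch.nn.functional.mse_loss(margin, torch.tensor([1.0]))   # (4.5 - 1.0) ** 2 = 12.25
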
class ConstantMarginMSE(MarginMSE):
    def __init__(self, margin: float = 1.0):
        super().__init__(margin)

class SupervisedMarginMSE(MarginMSE):
    def __init__(self):
        super().__init__("scores")

class RankNet(PairwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss

class KLDivergence(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss

class InfoNCE(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

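The sketch below mirrors what InfoNCE does internally on made-up tensors: the argmax of the graded targets selects the positive document per query, and cross-entropy is applied over each score list.

# illustrative only: two queries, three candidate documents each
scores = torch.tensor([[9.0, 3.0, 1.0], [2.0, 8.0, 0.5]])
targets = torch.tensor([[1, 0, 0], [0, 1, 0]])
loss = torch.nn.functional.cross_entropy(scores, targets.argmax(dim=1))
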
class ApproxLossFunction(ListwiseLossFunction):
    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks

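A quick illustration of the differentiable rank approximation on made-up scores; a small temperature makes the sigmoid-based approximation close to the true ranks.

# illustrative only: one query with three documents
scores = torch.tensor([[4.0, 1.0, 2.5]])
approx_ranks = ApproxLossFunction.get_approx_ranks(scores, temperature=0.1)
# approx_ranks is close to tensor([[1., 3., 2.]]) for a sufficiently small temperature
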
class ApproxNDCG(ApproxLossFunction):
    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()

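A small numeric sketch (made-up values) of the NDCG building blocks: when the supplied ranks already match the ideal ordering, the NDCG is 1, so the 1 - ndcg loss rewards better orderings.

# illustrative only: exact ranks instead of approximated ranks
targets = torch.tensor([[3.0, 1.0, 0.0]])
ranks = torch.tensor([[1.0, 2.0, 3.0]])        # documents already in the ideal order
ndcg = ApproxNDCG.get_ndcg(ranks, targets)     # tensor([1.])
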
class ApproxMRR(ApproxLossFunction):
    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()

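Another made-up numeric check: with the single relevant document at (exact) rank 2, the reciprocal rank is 0.5, so that query would contribute 1 - 0.5 = 0.5 to the loss.

# illustrative only: exact ranks, single relevant document at rank 2
targets = torch.tensor([[0.0, 1.0, 0.0]])
ranks = torch.tensor([[1.0, 2.0, 3.0]])
mrr = ApproxMRR.get_mrr(ranks, targets)   # tensor([0.5000])
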
class ApproxRankMSE(ApproxLossFunction):
    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss

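The target ranks above are derived from the labels by a double argsort; a small made-up example of that step:

# illustrative only: a higher label maps to a better (lower) target rank
targets = torch.tensor([[1.0, 3.0, 0.0]])
ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1   # tensor([[2, 1, 3]])
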
class NeuralLossFunction(ListwiseLossFunction):
    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        one = torch.ones((dim, 1), device=scores.device)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets

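NeuralLossFunction does not implement compute_loss itself, so it must be subclassed. The hypothetical subclass below is not part of the library; it only illustrates how get_sorted_targets could feed a listwise objective, comparing the soft-sorted targets against the ideally sorted targets.

# hypothetical subclass, for illustration only
class _NeuralSortMSE(NeuralLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        pred_sorted = self.get_sorted_targets(scores, targets)
        ideal_sorted = targets.sort(dim=-1, descending=True).values
        return torch.nn.functional.mse_loss(pred_sorted, ideal_sorted)
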
class InBatchLossFunction(LossFunction):
    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples
        if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
            raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.neg_sampling_technique == "all_and_non_first":
            neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
        elif self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        return neg_mask

    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        num_queries, num_docs = output.scores.shape
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
                neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs

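To see what the default sampling ("all"/"all") selects, here is the flattened index arithmetic on a made-up 2-query, 2-document batch, written out with plain tensors rather than output/batch objects:

# illustrative only: 2 queries x 2 docs, flattened in-batch indices 0..3
num_queries, num_docs = 2, 2
min_idx = torch.arange(num_queries)[:, None] * num_docs       # tensor([[0], [2]])
max_idx = min_idx + num_docs                                   # tensor([[2], [4]])
idcs = torch.arange(num_queries * num_docs)[None]
pos_mask = idcs.greater_equal(min_idx) & idcs.less(max_idx)    # a query's own documents
neg_mask = idcs.less(min_idx) | idcs.greater_equal(max_idx)    # the other queries' documents
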
class ScoreBasedInBatchLossFunction(InBatchLossFunction):

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        idcs = targets.argsort(descending=True).argsort().cpu()
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        score_diff = (max_score - targets).cpu()
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            assert neg_mask[query_idx, additional_neg_idcs].all().item()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples).item()
        return neg_mask

class InBatchCrossEntropy(InBatchLossFunction):
    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

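A tensor-level sketch of the cross-entropy step above, with made-up scores; it assumes the scores handed to compute_loss are arranged so that each row's own (positive) document is in the first column, which is why the target indices are all zero.

# illustrative only: 2 queries, first column assumed to be each query's own document
scores = torch.tensor([[5.0, 0.5, -1.0], [4.0, 2.0, 0.0]])
targets = torch.zeros(scores.shape[0], dtype=torch.long)
loss = torch.nn.functional.cross_entropy(scores, targets)
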
class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

class RegularizationLossFunction(EmbeddingLossFunction):
    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        return query_embeddings.embeddings, doc_embeddings.embeddings

class L2Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class L1Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class FLOPSRegularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        anti_zero = 1 / (torch.sum(query_embeddings) ** 2) + 1 / (torch.sum(doc_embeddings) ** 2)
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss + anti_zero
        return loss

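As a closing illustration, the FLOPS term penalizes the squared mean absolute activation of each embedding dimension, which encourages sparsity across the batch. The tensor below is made up and the penalty is computed directly, without a BiEncoderOutput object:

# illustrative only: 4 embeddings with 3 dimensions each
embeddings = torch.tensor([[0.0, 2.0, 0.0],
                           [0.0, 1.0, 0.5],
                           [1.0, 0.0, 0.0],
                           [0.0, 3.0, 0.0]])
flops = torch.sum(torch.mean(torch.abs(embeddings), dim=0) ** 2)  # dimension 1 dominates the penalty
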