Source code for lightning_ir.models.t5.model
1import torch
2from transformers import BatchEncoding
3
4from ...cross_encoder.model import CrossEncoderModel, CrossEncoderOutput
5from .config import T5CrossEncoderConfig
6
7
[docs]
class ScaleLinear(torch.nn.Linear):
    """Linear layer that rescales its input by ``d ** -0.5`` before projecting.

    ``d`` is the size of the input's last dimension. The rescaling mirrors what
    Mesh TensorFlow applies to the model output before projecting it onto the
    vocabulary (link in :meth:`forward`).
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Multiply ``input`` by the inverse square root of its last dimension, then apply the linear map.

        Args:
            input: Tensor whose last dimension is fed to the linear layer.

        Returns:
            The (weight, optional bias) projection of the rescaled input.
        """
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        last_dim = input.shape[-1]
        rescaled = input * (last_dim ** -0.5)
        return super().forward(rescaled)
15
16
[docs]
class T5CrossEncoderModel(CrossEncoderModel):
    """Cross-encoder for T5-style encoder-decoder backbones.

    Scoring runs one decoder step (decoder input id ``0``) and projects the
    result through a :class:`ScaleLinear` head. The head emits two logits per
    input when ``config.decoder_strategy == "mono"`` and a single logit
    otherwise. For ``"mono"``, the final score is the log-softmax probability
    of the first of the two logits.
    """

    config_class = T5CrossEncoderConfig

    # Weight keys declared as tied for the transformers load/save machinery.
    # NOTE(review): "linear.weight" is listed even though the TODO below states that tying
    # the scoring head to the LM head does not currently work -- confirm this is intended.
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "linear.weight"]

    def __init__(self, config: T5CrossEncoderConfig, *args, **kwargs):
        """Build the backbone via the parent class and attach the scoring head.

        Args:
            config: Model configuration; ``decoder_strategy``, ``hidden_size`` and
                ``linear_bias`` are read here.
            *args: Forwarded unchanged to :class:`CrossEncoderModel`.
            **kwargs: Forwarded unchanged to :class:`CrossEncoderModel`.
        """
        super().__init__(config, *args, **kwargs)
        self.config: T5CrossEncoderConfig
        # "mono" needs two logits per input (later normalised with log-softmax); other strategies use one.
        num_head_outputs = 2 if self.config.decoder_strategy == "mono" else 1
        self.linear = ScaleLinear(config.hidden_size, num_head_outputs, bias=config.linear_bias)

    # TODO tying of weights does not work when setting linear to only use a slice of the lm head for efficiency
    # def get_output_embeddings(self):
    #     shared = self.shared
    #     if self.config.decoder_strategy == "mono":
    #         self.linear.weight.data = shared.weight.data[[1176, 6136]]
    #     elif self.config.decoder_strategy == "rank":
    #         self.linear.weight.data = shared.weight.data[[32089]]
    #     else:
    #         raise ValueError("Unknown decoder strategy")
    #     return shared

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Score a batch of tokenized inputs.

        Side effect: ``encoding["decoder_input_ids"]`` is set to a ``(batch, 1)``
        tensor of zeros (dtype ``long``, on the same device as ``input_ids``)
        before the parent ``forward`` is called.

        Args:
            encoding: Tokenized inputs; must contain ``"input_ids"``, whose first
                dimension is taken as the batch size.

        Returns:
            The parent's :class:`CrossEncoderOutput`. For the ``"mono"`` strategy its
            ``scores`` are replaced by a flat tensor holding, per input, the
            log-softmax probability of the first of the two head logits.

        Raises:
            ValueError: If the parent ``forward`` returns no scores.
        """
        input_ids = encoding["input_ids"]
        # One decoder step per input, starting from token id 0.
        encoding["decoder_input_ids"] = torch.zeros(
            (input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device
        )
        output = super().forward(encoding)
        raw_scores = output.scores
        if raw_scores is None:
            raise ValueError("Scores are None")
        if self.config.decoder_strategy != "mono":
            return output
        # Normalise each input's pair of logits and keep the log-probability of the first one.
        pair_log_probs = torch.nn.functional.log_softmax(raw_scores.view(-1, 2), dim=-1)
        output.scores = pair_log_probs[:, 0].view(-1)
        return output