@@ -5,17 +5,18 @@
 
 import torch
 
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
 from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -31,6 +32,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -43,6 +45,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
     Returns:
 
     Example:
@@ -97,21 +105,17 @@ def lce_forward(
         )
 
     else:
-        logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Ensure tensors are on the same device
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits, shift_labels)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
 
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
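
For reviewers, a minimal standalone sketch of what the new `slice_indices` expression selects; the toy dimensions and `lm_head` below are assumptions for illustration, not part of the patch:

```python
import torch

# Toy sizes, assumed for illustration only.
batch, seq_len, hidden, vocab = 2, 16, 8, 32
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

for logits_to_keep in (0, 1, 4):
    # Same expression as the patch: 0 -> slice(0, None) keeps every position
    # (the special case), k > 0 -> slice(-k, None) keeps only the last k.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = lm_head(hidden_states[:, slice_indices, :])
    print(logits_to_keep, tuple(logits.shape))  # (2, 16, 32), (2, 1, 32), (2, 4, 32)

# The 1D-tensor form instead keeps arbitrary sequence positions, e.g. the
# last token of each sequence in a packed batch.
keep = torch.tensor([0, 5, 9])
print(tuple(lm_head(hidden_states[:, keep, :]).shape))  # (2, 3, 32)
```

Applying the head only to the kept positions is where the memory saving comes from: the `(batch, seq, vocab)` logits tensor shrinks to `(batch, kept, vocab)` before it is ever materialized, which matters for long sequences and large vocabularies.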
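A rough sketch of the old-name handling that the `deprecate_kwarg` decorator provides, shown on a hypothetical stub rather than the real forward; the exact warning text is the library's, not reproduced here:

```python
from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward_stub(logits_to_keep: int = 0) -> int:
    # Hypothetical stand-in for lce_forward, for illustration only.
    return logits_to_keep

assert forward_stub(logits_to_keep=5) == 5
# Callers still using the old keyword get a deprecation warning that names
# version 4.50, and the value is forwarded to `logits_to_keep`.
assert forward_stub(num_logits_to_keep=5) == 5
```

This keeps the patched `lce_forward` signature in sync with upstream `transformers`, which made the same `num_logits_to_keep` to `logits_to_keep` rename.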