
Commit a78a61d ("Add rest")
1 parent b85e556, commit a78a61d
8 files changed: 67 additions & 44 deletions

src/liger_kernel/transformers/model/gemma.py
Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 from typing import Union
 
 import torch
-import torch.nn.functional as F
 
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache

src/liger_kernel/transformers/model/gemma2.py
Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@
 from typing import Union
 
 import torch
-import torch.nn.functional as F
 
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache

src/liger_kernel/transformers/model/mistral.py
Lines changed: 19 additions & 15 deletions

@@ -5,17 +5,18 @@
 
 import torch
 
-from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
 from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -31,6 +32,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -43,6 +45,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
     Returns:
 
     Example:

@@ -97,21 +105,17 @@ def lce_forward(
         )
 
     else:
-        logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Ensure tensors are on the same device
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits, shift_labels)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
 
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
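
Note on the new logits_to_keep handling: the following is a minimal sketch (not part of the commit) of how the slice_indices expression selects hidden states before the lm_head projection. The toy shapes and the stand-in lm_head are illustrative assumptions, not liger_kernel code.

import torch

hidden_states = torch.randn(2, 6, 8)          # (batch, seq_len, hidden_size), illustrative sizes
lm_head = torch.nn.Linear(8, 32, bias=False)  # stand-in for self.lm_head

# int: keep only the last N positions
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 1, 32])

# 0 is the special case: slice(0, None) keeps every position
logits_to_keep = 0
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 6, 32])

# 1D tensor: keep explicit sequence positions (useful for packed formats)
logits_to_keep = torch.tensor([0, 3, 5])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(lm_head(hidden_states[:, slice_indices, :]).shape)  # torch.Size([2, 3, 32])
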

src/liger_kernel/transformers/model/mixtral.py
Lines changed: 12 additions & 11 deletions

@@ -12,6 +12,7 @@
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

@@ -144,6 +145,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy

@@ -161,7 +163,7 @@ def lce_forward(
     output_router_logits: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""

@@ -171,10 +173,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -235,15 +239,12 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
-            )
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
 
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
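
Every lce_forward touched by this commit also gains the @deprecate_kwarg decorator, so callers still passing the old num_logits_to_keep keyword keep working until transformers 4.50. The following is a simplified stand-in (not the transformers implementation, and the stub function name is made up) illustrating the intended behavior: the old keyword is remapped to the new name and a warning is emitted.

import functools
import warnings

def deprecate_kwarg(old_name, version, new_name):
    # Simplified stand-in for transformers.utils.deprecation.deprecate_kwarg
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward_stub(logits_to_keep=0):
    return logits_to_keep

print(lce_forward_stub(num_logits_to_keep=1))  # warns, returns 1
print(lce_forward_stub(logits_to_keep=2))      # no warning, returns 2
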

src/liger_kernel/transformers/model/mllama.py
Lines changed: 9 additions & 4 deletions

@@ -11,6 +11,7 @@
 from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

@@ -133,6 +134,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(

@@ -151,7 +153,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -161,10 +163,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -225,7 +229,8 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,

src/liger_kernel/transformers/model/olmo2.py
Lines changed: 9 additions & 4 deletions

@@ -10,10 +10,12 @@
 from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -29,7 +31,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -39,10 +41,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -98,7 +102,8 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,

src/liger_kernel/transformers/model/phi3.py
Lines changed: 9 additions & 4 deletions

@@ -11,6 +11,7 @@
 from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

@@ -126,6 +127,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -141,7 +143,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -151,10 +153,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -223,7 +227,8 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,

src/liger_kernel/transformers/model/qwen2.py
Lines changed: 9 additions & 4 deletions

@@ -11,6 +11,7 @@
 from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

@@ -125,6 +126,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -140,7 +142,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -150,10 +152,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -209,7 +213,8 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
