31 commits
2f31352
gemma4: implementation plan
amrtini Apr 16, 2026
982e1cb
gemma4: narrow plan scope to 31B text-only
amrtini Apr 16, 2026
cd21e6b
gemma4: assimilate user's deployment spec into the plan
amrtini Apr 16, 2026
7996fa0
gemma4: mirror full gemma3 test coverage, not just one test file
amrtini Apr 16, 2026
3c69790
gemma4: add LigerRMSNormForGemma4 (ones init, no +1 offset)
amrtini Apr 16, 2026
e83f09f
gemma4: scaffold causal_forward for Gemma4ForCausalLM
amrtini Apr 16, 2026
39422cc
gemma4: add apply_liger_kernel_to_gemma4_text + model-type registration
amrtini Apr 16, 2026
562293f
gemma4: align monkey_patch function style with gemma3 sibling
amrtini Apr 16, 2026
86fa3f1
gemma4: export apply_liger_kernel_to_gemma4_text from package root
amrtini Apr 16, 2026
cfac774
gemma4: add revert_liger_kernel_to_gemma4_text test helper
amrtini Apr 16, 2026
8f33299
gemma4: add mini_gemma4_text bf16 convergence test
amrtini Apr 16, 2026
af35978
gemma4: add mini_gemma4_text to bf16 logits-parity convergence test
amrtini Apr 16, 2026
537932f
gemma4: add test_apply_liger_kernel_to_instance_for_gemma4_text
amrtini Apr 16, 2026
101df25
gemma4: polish comments from whole-branch review
amrtini Apr 16, 2026
1709404
gemma4: handle with_scale=False (v_norm) in RMSNorm subclass
amrtini Apr 16, 2026
f35647f
gemma4: wrap LigerGEGLUMLP to absorb Gemma4TextMLP's layer_idx arg
amrtini Apr 16, 2026
d93ec5e
gemma4: skip rope kernel swap (signature incompatibility with HF Gemm…
amrtini Apr 16, 2026
ac0c794
gemma4: loosen bf16 convergence tolerances for 6-layer mini model
amrtini Apr 16, 2026
650f9bb
gemma4: also bump logprobs_atol in test_mini_models.py
amrtini Apr 16, 2026
6d2470e
gemma4: PR description with LUMI-measured numbers
amrtini Apr 16, 2026
4b8979d
gemma4: address pre-PR review findings
amrtini Apr 17, 2026
2a9152b
gemma4: align PR description with maintainer AI-assisted conventions
amrtini Apr 17, 2026
58c881a
gemma4: drop internal planning docs (not for upstream)
amrtini Apr 17, 2026
837f442
address review: double-wide MLP, gemma4 model type, rope default
amrtini Apr 23, 2026
3a4ccae
test: add Gemma 4 double-wide MLP edge case tests
amrtini Apr 23, 2026
afc113a
review: remove gemma4 model type mapping (defer to multimodal PR)
amrtini Apr 24, 2026
d7c5a62
style: fix import sort and assert formatting in test_geglu.py
amrtini Apr 25, 2026
9c81e46
[Gemma 4] Add multimodal support (apply_liger_kernel_to_gemma4 for Ge…
dvdimitrov13 Apr 26, 2026
e50fb42
fix: filter text_classes by isinstance(cls, type) under unittest.mock…
dvdimitrov13 Apr 26, 2026
ae02d34
style: ruff format
dvdimitrov13 Apr 26, 2026
73f941f
fix: same isinstance(cls, type) filter in apply_liger_kernel_to_gemma…
dvdimitrov13 Apr 26, 2026
2 changes: 2 additions & 0 deletions README.md
@@ -259,6 +259,8 @@ loss.backward()
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma4 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4_text` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma4 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma4` | RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
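For context, a minimal usage sketch of the new entry points. This is hedged: the checkpoint slug is illustrative (mirroring the one used in the `causal_forward` docstring below), not a verified model name, and class-level patching is assumed to be applied before the model is instantiated.

```python
# Sketch: apply the class-level Gemma 4 patches before building the model so
# the swapped modules (RMSNorm, GeGLU, fused linear cross-entropy) are used.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma4_text

apply_liger_kernel_to_gemma4_text()

# "google/gemma-4-31b" is an illustrative slug, not a confirmed checkpoint.
model = AutoModelForCausalLM.from_pretrained("google/gemma-4-31b")
```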
6 changes: 6 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -42,6 +42,8 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma4_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
@@ -117,6 +119,8 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_gemma4",
"apply_liger_kernel_to_gemma4_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
@@ -203,6 +207,8 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_gemma4",
"apply_liger_kernel_to_gemma4_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
25 changes: 25 additions & 0 deletions src/liger_kernel/transformers/geglu.py
@@ -20,3 +20,28 @@ def __init__(self, config):

def forward(self, x):
return self.down_proj(LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))


class LigerGEGLUMLPForGemma4(LigerGEGLUMLP):
"""GEGLU MLP wrapper matching Gemma4TextMLP's (config, layer_idx) constructor.

HF's Gemma4TextMLP conditionally doubles intermediate_size for KV-shared layers
when ``config.use_double_wide_mlp=True``. This subclass replicates that logic
so the class-level swap works for all Gemma 4 variants (31B text, future MoE).

See: https://github.com/huggingface/transformers/blob/74a2a4d0c/src/transformers/models/gemma4/modeling_gemma4.py#L1030-L1035
"""

def __init__(self, config, layer_idx=None):
super().__init__(config)
# Match HF's conditional doubling for KV-shared layers
if layer_idx is not None and getattr(config, "use_double_wide_mlp", False):
num_hidden = getattr(config, "num_hidden_layers", 0)
num_kv_shared = getattr(config, "num_kv_shared_layers", 0)
first_kv_shared = num_hidden - num_kv_shared
if num_kv_shared > 0 and layer_idx >= first_kv_shared:
doubled = config.intermediate_size * 2
self.intermediate_size = doubled
self.gate_proj = nn.Linear(self.hidden_size, doubled, bias=False)
self.up_proj = nn.Linear(self.hidden_size, doubled, bias=False)
self.down_proj = nn.Linear(doubled, self.hidden_size, bias=False)
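To make the doubling condition concrete, here is a self-contained sketch of the same layer-index arithmetic (the config values are hypothetical, chosen only to exercise both branches, and do not come from a real Gemma 4 checkpoint):

```python
# Hypothetical config: 6 layers, last 2 KV-shared, doubling enabled.
from types import SimpleNamespace

cfg = SimpleNamespace(
    num_hidden_layers=6,
    num_kv_shared_layers=2,
    intermediate_size=1024,
    use_double_wide_mlp=True,
)

first_kv_shared = cfg.num_hidden_layers - cfg.num_kv_shared_layers  # 4
for layer_idx in range(cfg.num_hidden_layers):
    wide = cfg.use_double_wide_mlp and layer_idx >= first_kv_shared
    width = cfg.intermediate_size * (2 if wide else 1)
    print(layer_idx, width)  # layers 0-3 -> 1024, KV-shared layers 4-5 -> 2048
```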
297 changes: 297 additions & 0 deletions src/liger_kernel/transformers/model/gemma4.py
@@ -0,0 +1,297 @@
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.cache_utils import Cache
from transformers.utils import logging

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

try:
from liger_kernel.transformers.model.output_classes import LigerGemma4CausalLMOutputWithPast
except ImportError:
# Older transformers without gemma4 — multimodal_forward is then unreachable
# because monkey_patch.apply_liger_kernel_to_gemma4 imports gemma4 modules
# behind the same try/except.
LigerGemma4CausalLMOutputWithPast = None

logger = logging.get_logger(__name__)


def causal_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Fused-linear-cross-entropy forward for Gemma4ForCausalLM. Mirrors liger's
gemma3 causal_forward. Gemma 4 31B uses final_logit_softcapping=30.0, so
the softcap branch is exercised on the non-fused path.

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, Gemma4ForCausalLM

>>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-4-31b") # illustrative slug
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-31b")

>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""

if self.training and self.config._attn_implementation != "eager":
logger.warning_once(
"It is strongly recommended to train Gemma4 models with the `eager` attention implementation "
f"instead of `{self.config._attn_implementation}`. Use `eager` with "
"`AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)

hidden_states = outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]
shift_labels = loss_kwargs.pop("shift_labels", None)
loss = None
logits = None
token_accuracy = None
predicted_tokens = None

if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
# final_logit_softcapping via getattr: some future Gemma 4 variants may omit the attribute entirely.
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
final_logit_softcapping=getattr(self.config, "final_logit_softcapping", None),
**loss_kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
final_logit_softcapping = getattr(self.config, "final_logit_softcapping", None)
if final_logit_softcapping is not None:
logits = logits / final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * final_logit_softcapping
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**loss_kwargs,
)

if not return_dict:
output_tuple = (logits,) + outputs[1:]
output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple
return output_tuple

return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)


def multimodal_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
input_features_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_position_ids: Optional[torch.LongTensor] = None,
video_position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
mm_token_type_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**lm_kwargs,
):
r"""Fused-linear-cross-entropy forward for ``Gemma4ForConditionalGeneration``.

Mirrors :func:`liger_kernel.transformers.model.gemma3.multimodal_forward`
with Gemma 4-specific kwargs (``pixel_values_videos``, ``input_features``,
``image_position_ids``, ``video_position_ids``, ``mm_token_type_ids``) and
output fields (``image_hidden_states``, ``audio_hidden_states``).

The win on Gemma 4 multimodal is large: with vocab=262,144, the (B, T, V)
logits tensor is ~17 GB in bf16 at B=4, T=8192 (and another ~34 GB once the
loss path upcasts to fp32), OOMing 96 GB cards. Routing the loss through
``LigerForCausalLMLoss`` materializes only the loss scalar.

labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either
be in `[0, ..., config.text_config.vocab_size]` or -100 (see `input_ids`
docstring). Tokens with indices set to `-100` are ignored.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`,
calculate logits for all `input_ids` (special case). If a `torch.Tensor`,
must be 1D corresponding to the indices to keep in the sequence-length
dimension.
"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
input_features=input_features,
attention_mask=attention_mask,
input_features_mask=input_features_mask,
position_ids=position_ids,
image_position_ids=image_position_ids,
video_position_ids=video_position_ids,
past_key_values=past_key_values,
mm_token_type_ids=mm_token_type_ids,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)

hidden_states = outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

text_cfg = self.config.get_text_config()
softcap = getattr(text_cfg, "final_logit_softcapping", None)
shift_labels = lm_kwargs.pop("shift_labels", None)

loss = None
logits = None
token_accuracy = None
predicted_tokens = None

if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
result = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=text_cfg.hidden_size,
final_logit_softcapping=softcap,
**lm_kwargs,
)
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
else:
logits = self.lm_head(kept_hidden_states)
if softcap is not None:
logits = logits / softcap
logits = torch.tanh(logits)
logits = logits * softcap
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=text_cfg.vocab_size,
**lm_kwargs,
)

if not return_dict:
output_tuple = (logits,) + outputs[1:]
output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
output_tuple = output_tuple + (predicted_tokens,) if predicted_tokens is not None else output_tuple
return output_tuple

return LigerGemma4CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=getattr(outputs, "image_hidden_states", None),
audio_hidden_states=getattr(outputs, "audio_hidden_states", None),
token_accuracy=token_accuracy,
predicted_tokens=predicted_tokens,
)
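Two standalone sanity checks on the logic above, offered as sketches rather than as part of the patch: the tanh softcap applied on the non-fused path, and the logits-memory arithmetic behind the multimodal docstring (a batch size of 4 is assumed in order to reproduce the quoted numbers):

```python
import torch

# Tanh softcapping as on the non-fused path: logits are squashed into
# (-softcap, softcap) while remaining monotonic.
def softcap_logits(logits: torch.Tensor, softcap: float = 30.0) -> torch.Tensor:
    return torch.tanh(logits / softcap) * softcap

print(softcap_logits(torch.tensor([0.0, 10.0, 100.0, 1000.0])))
# approx [0.0, 9.64, 29.92, 30.0]; never exceeds +/-30

# Logits-memory arithmetic (B=4 assumed; vocab and seq len from the docstring).
B, T, V = 4, 8192, 262_144
print(B * T * V * 2 / 1e9)  # ~17.2 GB for the bf16 logits
print(B * T * V * 4 / 1e9)  # ~34.4 GB more if the loss path upcasts to fp32
```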
13 changes: 13 additions & 0 deletions src/liger_kernel/transformers/model/output_classes.py
@@ -19,6 +19,11 @@
except Exception:
_Gemma3CausalLMOutputWithPast = None

try:
from transformers.models.gemma4.modeling_gemma4 import Gemma4CausalLMOutputWithPast as _Gemma4CausalLMOutputWithPast
except Exception:
_Gemma4CausalLMOutputWithPast = None

try:
from transformers.models.glm4v_moe.modeling_glm4v_moe import (
Glm4vMoeCausalLMOutputWithPast as _Glm4vMoeCausalLMOutputWithPast,
@@ -101,6 +106,14 @@ class LigerGemma3CausalLMOutputWithPast(_Gemma3CausalLMOutputWithPast):
predicted_tokens: Optional[torch.LongTensor] = None


if _Gemma4CausalLMOutputWithPast is not None:

@dataclass
class LigerGemma4CausalLMOutputWithPast(_Gemma4CausalLMOutputWithPast):
token_accuracy: Optional[torch.FloatTensor] = None
predicted_tokens: Optional[torch.LongTensor] = None


if _Glm4vMoeCausalLMOutputWithPast is not None:

@dataclass