From c36e542903ee77734f7e5d5d79a1d0265b2136b1 Mon Sep 17 00:00:00 2001 From: Dongcheng Ye Date: Wed, 1 Apr 2026 00:04:21 +0800 Subject: [PATCH] Enable SwiGLU patching for Qwen3-VL --- README.md | 11 +++-- src/liger_kernel/transformers/monkey_patch.py | 31 +++++++----- test/transformers/test_monkey_patch.py | 47 +++++++++++++++++++ 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f7c916f1c..51dd43e56 100644 --- a/README.md +++ b/README.md @@ -260,11 +260,12 @@ loss.backward() | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | -| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | -| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | -| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | -| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | -| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen3-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_vl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Qwen3.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_5` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss | diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 40c762a3e..5f4de30c3 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -1789,7 +1789,7 @@ def apply_liger_kernel_to_qwen3_vl( cross_entropy: bool = False, fused_linear_cross_entropy: bool = True, rms_norm: bool = True, - swiglu: bool = False, + swiglu: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -1802,7 +1802,7 @@ def apply_liger_kernel_to_qwen3_vl( 
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True. If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. - swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -1836,7 +1836,10 @@ def apply_liger_kernel_to_qwen3_vl( else: modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward - if model is not None and rms_norm: + if swiglu: + modeling_qwen3_vl.Qwen3VLTextMLP = LigerSwiGLUMLP + + if model is not None: if isinstance(model, Qwen3VLForConditionalGeneration): text_model: Qwen3VLTextModel = model.model.language_model elif isinstance(model, Qwen3VLModel): @@ -1851,16 +1854,20 @@ def apply_liger_kernel_to_qwen3_vl( _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama") if text_model is not None: - _patch_qwen3_vl_rms_norm(text_model.norm) + if rms_norm: + _patch_qwen3_vl_rms_norm(text_model.norm) for decoder_layer in text_model.layers: - _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm) - _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm) - self_attn = getattr(decoder_layer, "self_attn", None) - if self_attn is not None: - if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None: - _patch_qwen3_vl_rms_norm(self_attn.q_norm) - if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: - _patch_qwen3_vl_rms_norm(self_attn.k_norm) + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm) + _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm) + self_attn = getattr(decoder_layer, "self_attn", None) + if self_attn is not None: + if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None: + _patch_qwen3_vl_rms_norm(self_attn.q_norm) + if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None: + _patch_qwen3_vl_rms_norm(self_attn.k_norm) def apply_liger_kernel_to_qwen3_vl_moe( diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 542b45c76..5b76462d4 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -554,6 +554,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( LigerRMSNorm.forward ) for decoder_layer in dummy_model_instance.model.language_model.layers: + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward @@ -574,6 +575,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation( LigerRMSNorm.forward ) for decoder_layer in dummy_model_instance.model.language_model.layers: + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( LigerRMSNorm.forward @@ -651,6 +653,7 @@ def 
test_apply_liger_kernel_to_instance_for_qwen3_vl(): LigerRMSNorm.forward ) for decoder_layer in dummy_model_instance.language_model.layers: + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward @@ -671,6 +674,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl(): LigerRMSNorm.forward ) for decoder_layer in dummy_model_instance.language_model.layers: + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( LigerRMSNorm.forward @@ -721,6 +725,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text(): # Note: Text models don't have forward method patching, so skip this check assert inspect.getsource(dummy_model_instance.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.layers: + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward @@ -739,6 +744,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text(): # Note: Text models don't have forward method patching, so skip this check assert inspect.getsource(dummy_model_instance.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.layers: + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( LigerRMSNorm.forward @@ -756,6 +762,47 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text(): pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") +@pytest.mark.skipif(not is_qwen3_vl_available(), reason="qwen3_vl module not available") +def test_apply_liger_kernel_to_qwen3_vl_swiglu_flag_patches_mlp(): + with patch("transformers.models.qwen3_vl.modeling_qwen3_vl"): + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel + + config = transformers.models.qwen3_vl.configuration_qwen3_vl.Qwen3VLTextConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + tie_word_embeddings=True, + attention_dropout=0.0, + attention_bias=False, + **get_qwen3_vl_rope_config(), + ) + dummy_model_instance = Qwen3VLTextModel._from_config(config) + + for decoder_layer in dummy_model_instance.layers: + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + + monkey_patch.apply_liger_kernel_to_qwen3_vl( + model=dummy_model_instance, + rope=False, + cross_entropy=False, + fused_linear_cross_entropy=False, + rms_norm=False, + swiglu=True, + ) + + for 
decoder_layer in dummy_model_instance.layers: + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + + @pytest.mark.skipif(not is_qwen3_vl_moe_available(), reason="qwen3_vl_moe module not available") def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generation(): # Ensure any monkey patching is cleaned up for subsequent tests
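
Usage sketch (the checkpoint id and dtype below are illustrative assumptions; the function and class names are the ones this patch touches). With this change, `apply_liger_kernel_to_qwen3_vl` enables the SwiGLU kernel by default through both patching paths:

    import torch
    from transformers import Qwen3VLForConditionalGeneration
    from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

    # Path 1: patch before loading. swiglu=True (now the default) swaps
    # modeling_qwen3_vl.Qwen3VLTextMLP for LigerSwiGLUMLP, so every text MLP
    # in a subsequently constructed model is a Liger module.
    apply_liger_kernel_to_qwen3_vl()
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen3-VL-8B-Instruct",  # assumed checkpoint id, for illustration only
        torch_dtype=torch.bfloat16,
    )

    # Path 2 (alternative): patch an already-loaded instance in place; each
    # text decoder layer's mlp is converted via _patch_swiglu_module.
    apply_liger_kernel_to_qwen3_vl(model=model)

In both paths only the Qwen3-VL text decoder MLPs are swapped; the vision tower is left untouched, matching the layers asserted in the tests above.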