Enable SwiGLU patching for Qwen3-VL #1175
**README.md**: the supported models table gains a Qwen3-VL row.

```diff
@@ -260,11 +260,12 @@ loss.backward()
 | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen3-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_vl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen3.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_5` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
```

> **Collaborator** commented on lines +263 to +268: great catch
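With Qwen3-VL in the table, patching follows the same one-call flow as the other Qwen entries. A minimal sketch of the intended usage (the checkpoint id and auto class are illustrative assumptions, not part of this PR):

```python
import torch
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

# Patch the Qwen3-VL module classes before loading; with this PR the
# call covers RoPE, RMSNorm, SwiGLU, and the fused linear cross entropy
# loss by default.
apply_liger_kernel_to_qwen3_vl()

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
)
```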
**monkey_patch.py**: flip the `swiglu` default for `apply_liger_kernel_to_qwen3_vl` and wire the flag into both patching paths.
```diff
@@ -1789,7 +1789,7 @@ def apply_liger_kernel_to_qwen3_vl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
-    swiglu: bool = False,
+    swiglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
```
```diff
@@ -1802,7 +1802,7 @@ def apply_liger_kernel_to_qwen3_vl(
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
```
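Because the default flips to `True`, existing callers pick up the Liger SwiGLU MLP without any code change, and opting out stays a one-keyword call. A sketch using only flags from the signature above:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

# Keep the stock MLP while still using Liger RMSNorm and the fused loss.
apply_liger_kernel_to_qwen3_vl(swiglu=False)

# Or swap the fused linear cross entropy for the plain Liger cross
# entropy; per the docstring, the two loss flags cannot both be True.
apply_liger_kernel_to_qwen3_vl(
    cross_entropy=True,
    fused_linear_cross_entropy=False,
)
```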
```diff
@@ -1836,7 +1836,10 @@ def apply_liger_kernel_to_qwen3_vl(
     else:
         modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

-    if model is not None and rms_norm:
+    if swiglu:
+        modeling_qwen3_vl.Qwen3VLTextMLP = LigerSwiGLUMLP
+
+    if model is not None:
         if isinstance(model, Qwen3VLForConditionalGeneration):
             text_model: Qwen3VLTextModel = model.model.language_model
         elif isinstance(model, Qwen3VLModel):
```
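The class rebinding above only affects models constructed after the call, which is why the `model is not None` branch goes on to walk an already-loaded model and patch its modules in place (see the next hunk). A rough sketch of the two mechanisms; the `__get__` rebinding mirrors what the `_patch_swiglu_module` helper does internally, which this diff does not show:

```python
from transformers.models.qwen3_vl import modeling_qwen3_vl

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

# 1) Class-level: future Qwen3VLTextMLP instantiations are constructed
#    as LigerSwiGLUMLP; modules that already exist keep their old class.
modeling_qwen3_vl.Qwen3VLTextMLP = LigerSwiGLUMLP

# 2) Instance-level: rebind Liger's forward onto an existing MLP so a
#    model already in memory benefits too (assumed helper shape).
def patch_mlp_in_place(mlp) -> None:
    mlp.forward = LigerSwiGLUMLP.forward.__get__(mlp)
```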
```diff
@@ -1851,16 +1854,20 @@ def apply_liger_kernel_to_qwen3_vl(
         _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

         if text_model is not None:
-            _patch_qwen3_vl_rms_norm(text_model.norm)
+            if rms_norm:
+                _patch_qwen3_vl_rms_norm(text_model.norm)
             for decoder_layer in text_model.layers:
-                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
-                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
-                self_attn = getattr(decoder_layer, "self_attn", None)
-                if self_attn is not None:
-                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
-                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
-                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
-                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                    _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                    self_attn = getattr(decoder_layer, "self_attn", None)
+                    if self_attn is not None:
+                        if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                            _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                        if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                            _patch_qwen3_vl_rms_norm(self_attn.k_norm)


 def apply_liger_kernel_to_qwen3_vl_moe(
```

> **Collaborator** commented on lines +1857 to +1858: great catch!
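For orientation, the MLP being swapped computes `down_proj(SiLU(gate_proj(x)) * up_proj(x))`; `LigerSwiGLUMLP` fuses the elementwise SiLU-and-multiply into a single Triton kernel. An unfused PyTorch reference of the same computation (a sketch, not Liger's implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReferenceSwiGLUMLP(nn.Module):
    """Unfused SwiGLU MLP with the Qwen-style gate/up/down layout."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(gate(x)) * up(x) is the part Liger fuses into one kernel.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```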
**Monkey patch tests**: the existing Qwen3-VL instance tests now assert the MLP patch state on both the unpatched and patched sides, and a dedicated SwiGLU flag test is added.
```diff
@@ -554,6 +554,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation(
             LigerRMSNorm.forward
         )
         for decoder_layer in dummy_model_instance.model.language_model.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource(
                 LigerRMSNorm.forward
```
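These assertions verify patching indirectly: after the monkey patch rebinds `forward`, `inspect.getsource` on the bound method returns Liger's source instead of the stock implementation. A self-contained illustration of the idiom with toy classes (not the real modules; run as a script so `getsource` can find the file):

```python
import inspect


class Stock:
    def forward(self):
        return "stock"


class Liger:
    def forward(self):
        return "liger"


obj = Stock()
assert inspect.getsource(obj.forward) != inspect.getsource(Liger.forward)

# Rebind Liger's forward onto the existing instance, as the monkey patch
# does for each decoder layer's MLP.
obj.forward = Liger.forward.__get__(obj)
assert inspect.getsource(obj.forward) == inspect.getsource(Liger.forward)
```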
```diff
@@ -574,6 +575,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation(
             LigerRMSNorm.forward
         )
         for decoder_layer in dummy_model_instance.model.language_model.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
                 LigerRMSNorm.forward
```
```diff
@@ -651,6 +653,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl():
             LigerRMSNorm.forward
         )
         for decoder_layer in dummy_model_instance.language_model.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource(
                 LigerRMSNorm.forward
```
```diff
@@ -671,6 +674,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl():
             LigerRMSNorm.forward
         )
         for decoder_layer in dummy_model_instance.language_model.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
                 LigerRMSNorm.forward
```
```diff
@@ -721,6 +725,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text():
         # Note: Text models don't have forward method patching, so skip this check
         assert inspect.getsource(dummy_model_instance.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
         for decoder_layer in dummy_model_instance.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource(
                 LigerRMSNorm.forward
```
```diff
@@ -739,6 +744,7 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text():
         # Note: Text models don't have forward method patching, so skip this check
         assert inspect.getsource(dummy_model_instance.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
         for decoder_layer in dummy_model_instance.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
             assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
                 LigerRMSNorm.forward
```
```diff
@@ -756,6 +762,47 @@ def test_apply_liger_kernel_to_instance_for_qwen3_vl_text():
         pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")


+@pytest.mark.skipif(not is_qwen3_vl_available(), reason="qwen3_vl module not available")
+def test_apply_liger_kernel_to_qwen3_vl_swiglu_flag_patches_mlp():
+    with patch("transformers.models.qwen3_vl.modeling_qwen3_vl"):
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+        config = transformers.models.qwen3_vl.configuration_qwen3_vl.Qwen3VLTextConfig(
+            vocab_size=32000,
+            hidden_size=512,
+            intermediate_size=2048,
+            num_hidden_layers=2,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            head_dim=64,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=False,
+            tie_word_embeddings=True,
+            attention_dropout=0.0,
+            attention_bias=False,
+            **get_qwen3_vl_rope_config(),
+        )
+        dummy_model_instance = Qwen3VLTextModel._from_config(config)
+
+        for decoder_layer in dummy_model_instance.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
+
+        monkey_patch.apply_liger_kernel_to_qwen3_vl(
+            model=dummy_model_instance,
+            rope=False,
+            cross_entropy=False,
+            fused_linear_cross_entropy=False,
+            rms_norm=False,
+            swiglu=True,
+        )
+
+        for decoder_layer in dummy_model_instance.layers:
+            assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
+
+
 @pytest.mark.skipif(not is_qwen3_vl_moe_available(), reason="qwen3_vl_moe module not available")
 def test_apply_liger_kernel_to_instance_for_qwen3_vl_moe_for_conditional_generation():
     # Ensure any monkey patching is cleaned up for subsequent tests
```

> **Collaborator** commented on lines +765 to +803: redundant test? adding swiglu check in above test functions should be sufficient.