src/liger_kernel/transformers/monkey_patch.py (36 changes: 24 additions & 12 deletions)
@@ -52,13 +52,25 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
    module.in_place = in_place
    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__


def _patch_layer_norm_module(module, eps=1e-6):
    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    module.hidden_size = module.normalized_shape
    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__


def apply_liger_kernel_to_granite(
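
The net effect of the new helpers is easiest to see in isolation. The sketch below is not the library code: _bind_method_to_module is assumed to rebind a plain function as a bound method (e.g. via types.MethodType), and HFMLP and the LigerSwiGLUMLP stub are hypothetical stand-ins for a Hugging Face MLP and the real Liger module.

import types

import torch.nn as nn


def _bind_method_to_module(module, method_name, new_method):
    # Assumed behavior of liger_kernel's helper: rebind `new_method` so that
    # `self` resolves to `module` when the method is called.
    setattr(module, method_name, types.MethodType(new_method, module))


class LigerSwiGLUMLP(nn.Module):  # hypothetical stub for the Liger module
    def forward(self, x):
        return x  # the real module runs a fused SwiGLU kernel


class HFMLP(nn.Module):  # hypothetical stand-in for e.g. LlamaMLP
    def forward(self, x):
        return x


def _patch_swiglu_module(module, liger_module):
    _bind_method_to_module(module, "forward", liger_module.forward)
    # The line this PR adds: renaming the class makes print(model) report
    # the Liger module, so a patched model is recognizable at a glance.
    module.__class__.__name__ = liger_module.__name__


mlp = HFMLP()
_patch_swiglu_module(mlp, LigerSwiGLUMLP)
print(mlp)  # prints "LigerSwiGLUMLP()" instead of "HFMLP()"

Note that the rename mutates the class object itself, so every instance of that module class picks up the new display name; for whole-model patching that is exactly the intended behavior.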
@@ -134,7 +146,7 @@ def apply_liger_kernel_to_granite(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,7 +218,7 @@ def apply_liger_kernel_to_llama(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -296,7 +308,7 @@ def apply_liger_kernel_to_mllama(
            _patch_rms_norm_module(text_model.norm)
        for decoder_layer in text_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +382,7 @@ def apply_liger_kernel_to_mistral(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +454,7 @@ def apply_liger_kernel_to_mixtral(
        for decoder_layer in base_model.layers:
            if swiglu:
                for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +528,7 @@ def apply_liger_kernel_to_gemma(

        for decoder_layer in base_model.layers:
            if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +604,7 @@ def apply_liger_kernel_to_gemma2(

        for decoder_layer in base_model.layers:
            if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -776,7 +788,7 @@ def apply_liger_kernel_to_qwen2(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -849,7 +861,7 @@ def apply_liger_kernel_to_qwen2_vl(
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -916,7 +928,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -985,7 +997,7 @@ def apply_liger_kernel_to_phi3(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1048,7 +1060,7 @@ def apply_liger_kernel_to_olmo2(

        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
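
For completeness, a hypothetical end-to-end use of one of these entry points. The import path and the model keyword follow Liger Kernel's documented API; the checkpoint id is only an example, and the kwargs mirror the signatures visible in this diff.

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Patch the already-instantiated modules in place; after this call the
# module tree prints LigerRMSNorm / LigerSwiGLUMLP names thanks to the
# __class__.__name__ assignments added in this PR.
apply_liger_kernel_to_llama(model=model, rms_norm=True, swiglu=True)
print(model)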