
Fix incorrect module name when monkey_patch applied to instantiated model #629

Merged
vaibhavjindal merged 1 commit into main from fix-module-name on Mar 26, 2025
Conversation

@vaibhavjindal (Collaborator) commented Mar 25, 2025

Fix incorrect module name when monkey_patch applied to instantiated model.

Summary

Fixes: #625.
This PR sets module.__class__.__name__ correctly, which allows the module repr (what print(model) shows) to display the names of the Liger modules.
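For illustration, here is a minimal sketch of the idea (not the actual monkey_patch.py code; the helper name is hypothetical): after binding a Liger forward onto an already-instantiated module, the class name is updated as well so that print(model) reflects the patch.

import types
import torch.nn as nn

def _patch_instance_and_rename(module: nn.Module, liger_forward, liger_name: str) -> None:
    # Bind the Liger forward implementation onto the existing instance.
    module.forward = types.MethodType(liger_forward, module)
    # Rename the class so nn.Module.__repr__ (which prints self.__class__.__name__)
    # shows the Liger name. Note: this renames the class object itself,
    # so it applies to every instance of that class.
    module.__class__.__name__ = liger_name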

Results

Running the code referenced in the linked issue (#625), here is the output after this change:

Before monkey patch:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
===============================================
After monkey patch
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
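For reference, a minimal sketch of the snippet that exercises this in-place patching path (the exact code is the one linked from #625; this version assumes apply_liger_kernel_to_llama can patch an already-built model via its model argument and reuses the mini config from the snippet further below):

from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

mini_model_config = LlamaConfig(
    hidden_act="silu",
    hidden_size=1024,
    intermediate_size=2048,
    max_position_embeddings=8192,
    num_attention_heads=8,
    num_hidden_layers=4,
    num_key_value_heads=2,
    vocab_size=32000,
)
model = LlamaForCausalLM(mini_model_config)

print("Before monkey patch:")
print(model)

# Patch the already-instantiated model in place
# (assumes apply_liger_kernel_to_llama accepts a `model` kwarg).
apply_liger_kernel_to_llama(model=model)

print("After monkey patch")
print(model)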

Note

This output is not exactly identical to what you get if you apply the monkey patch before instantiating the model. For example, consider the following code:

from transformers.models.llama import LlamaConfig
from transformers.models.llama import LlamaForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()

mini_model_config = LlamaConfig(
    hidden_act="silu",
    hidden_size=1024,  # 4096
    initializer_range=0.02,
    intermediate_size=2048,  # 14336
    max_position_embeddings=8192,
    num_attention_heads=8,  # 32
    num_hidden_layers=4,  # 32
    num_key_value_heads=2,  # 8
    vocab_size=32000,  # 128256,
)
model = LlamaForCausalLM(mini_model_config)
print(model)

This code generates the following output:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
        )
        (input_layernorm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LigerRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)

Notice that the mlp layer is slightly different in this case compared to when the monkey patch was applied to the already-instantiated model: here, the mlp layer does not contain (act_fn): SiLU() in the output. This is a known limitation of this approach, since we are only changing __class__.__name__ and not the internal behavior of the __repr__ method.
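As a hedged illustration of the difference (variable names are hypothetical: instance_patched is a model patched after instantiation as in the first output above, class_patched is the model from the snippet just shown):

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # assumed import path

# instance_patched: only the class name was rewritten on the existing module.
mlp = instance_patched.model.layers[0].mlp
print(type(mlp).__name__)                       # "LigerSwiGLUMLP"
print(isinstance(mlp, LigerSwiGLUMLP))          # False -- still a LlamaMLP underneath
print("act_fn" in dict(mlp.named_children()))   # True  -- so SiLU still shows in print(model)

# class_patched: the layer really is a LigerSwiGLUMLP.
mlp = class_patched.model.layers[0].mlp
print(isinstance(mlp, LigerSwiGLUMLP))          # True
print("act_fn" in dict(mlp.named_children()))   # False -- no act_fn submodule, so no SiLU line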

Testing Done

pre merge tests

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@vaibhavjindal vaibhavjindal changed the title Fix incorrect module name when monkey_patch applied to instantiated m… Fix incorrect module name when monkey_patch applied to instantiated model Mar 25, 2025
@vaibhavjindal vaibhavjindal marked this pull request as ready for review March 26, 2025 00:18
Collaborator

@shivam15s left a comment

lgtm

Comment thread on src/liger_kernel/transformers/monkey_patch.py
@vaibhavjindal vaibhavjindal merged commit 3a5845b into main Mar 26, 2025
6 of 8 checks passed
@vaibhavjindal vaibhavjindal deleted the fix-module-name branch March 26, 2025 01:20


Development

Successfully merging this pull request may close these issues.

Model shows incorrect module names if monkey patch is applied to the model instance.

2 participants