Model shows incorrect module names if monkey patch is applied to the model instance. #625

Opened by @Tcc0403

🐛 Describe the bug

In #524, @jp1924 found that the model doesn't show Liger module names when the monkey patch is applied to a model instance. This is because the current implementation only binds Liger's forward and extra_repr methods to the existing module instances without changing their classes, so torch.nn.Module.__repr__, which builds each module header from self._get_name() (i.e. type(self).__name__), still reports the original names.

Note that this bug doesn't affect model training, but fixing it would make debugging easier.

See: torch.nn.Module, torch.nn.Module.__repr__, and the Liger monkey patch.
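The mechanism is easy to demonstrate in isolation. Below is a minimal sketch (an illustration only: this LigerRMSNorm is a stand-in defining just the two methods the patch rebinds, and nn.LayerNorm stands in for the original Hugging Face module), showing that instance-level rebinding swaps extra_repr but leaves the printed class name untouched:

import types

import torch.nn as nn

class LigerRMSNorm(nn.Module):
    # Stand-in for Liger's module: only the two methods the patch rebinds.
    def forward(self, x):
        return x

    def extra_repr(self):
        return "offset=0.0, in_place=True"

norm = nn.LayerNorm(8)  # stand-in for the original HF module

# Instance-level patching, mirroring what the monkey patch does today:
norm.forward = types.MethodType(LigerRMSNorm.forward, norm)
norm.extra_repr = types.MethodType(LigerRMSNorm.extra_repr, norm)

print(norm)  # LayerNorm(offset=0.0, in_place=True)
# extra_repr() is Liger's, but the header still says "LayerNorm" because
# __repr__ builds it from self._get_name(), i.e. type(self).__name__.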

Reproduce

from transformers.models.llama import LlamaConfig
from transformers.models.llama import LlamaForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Scaled-down Llama config for a quick repro (full-size values in comments)
mini_model_config = LlamaConfig(
    hidden_act="silu",
    hidden_size=1024,  # 4096
    initializer_range=0.02,
    intermediate_size=2048,  # 14336
    max_position_embeddings=8192,
    num_attention_heads=8,  # 32
    num_hidden_layers=4,  # 32
    num_key_value_heads=2,  # 8
    vocab_size=32000,  # 128256,
)
model = LlamaForCausalLM(mini_model_config)
print("Before monkey patch:")
print(model)
print("===============================================")


apply_liger_kernel_to_llama(model=model)
print("After monkey patch:")
print(model)

Output:
Before monkey patch:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
===============================================
After monkey patch:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06, offset=0.0, in_place=True)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
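
Note how the patched modules still print as LlamaRMSNorm even though their extra_repr output (offset=0.0, in_place=True) comes from LigerRMSNorm. One possible direction for a fix (a sketch only; _bind_repr_name is a hypothetical helper, not part of the Liger codebase) is to rebind _get_name on the instance as well, since nn.Module.__repr__ takes the header from it:

import types

def _bind_repr_name(module, liger_cls):
    # Hypothetical helper: make the repr header show the Liger class name.
    # nn.Module.__repr__ builds each header from self._get_name(), so an
    # instance-level override changes the printed name without changing
    # the module's actual class.
    module._get_name = types.MethodType(lambda self: liger_cls.__name__, module)

An alternative is reassigning module.__class__ to the Liger class, which would also change isinstance checks; the instance-level _get_name override above is narrower in scope.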

Versions

Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
Liger Kernel version: 0.5.4
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
HIP(ROCm) version: Not available
Triton version: 3.1.0
Transformers version: 4.49.0
XPU version: XPU Not Available
