[Refactor] Unify the CausalLM classes #3447

Open
babusid wants to merge 5 commits into mlc-ai:main from babusid:causal_lm_refactor

@babusid
Contributor

@babusid babusid commented Mar 6, 2026

This refactor may be hard to get fully right, but the idea is to reduce the substantial duplication across the CausalLM classes. I've converted the simpler models so far; some of the remaining models will be trickier, I think.

Not ready to merge yet, but I've marked it non-draft so Gemini can review it.

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request initiates a significant refactoring effort aimed at unifying the CausalLM classes within the project. By introducing a common base class and consolidating shared functionalities, the changes reduce code duplication, improve maintainability, and standardize the implementation of various language models. This is an initial step in a broader refactor, laying the groundwork for a more streamlined and consistent model architecture.

Highlights

  • Introduction of BaseForCausalLM: A new base class, BaseForCausalLM, was introduced in python/mlc_llm/nn/causal_lm.py to centralize common logic and reduce code duplication across various Causal Language Model implementations.
  • Refactoring of CausalLM Classes: Multiple model-specific ForCausalLM classes, including Baichuan, ChatGLM3, GPTBigCode, GPT-J, GPTNeoX, InternLM, InternLM2, Orion, StableLM, and Starcoder2, were refactored to inherit from the new BaseForCausalLM class.
  • Consolidation of Common Methods: Shared methods such as embed, prefill, decode, batch_forward, create_paged_kv_cache, and get_default_spec were moved from individual model implementations into the BaseForCausalLM, significantly reducing redundant code.
  • Parameter Mapping Updates: The internlm2_loader.py was updated to include a _name_transform function, ensuring correct parameter mapping for embed_tokens and lm_head in the refactored InternLM2ForCausalLM class.
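The parameter-mapping point in the last highlight can be illustrated with a minimal sketch. It assumes the checkpoint still stores the embedding and output weights under the old names (`model.tok_embeddings`, `output`) while the refactored model exposes them as `model.embed_tokens` and `lm_head`. The function name, the rename table, and the prefix-matching approach are all hypothetical; this is not the code from internlm2_loader.py.

```python
def name_transform_sketch(mlc_name: str) -> str:
    """Map a refactored parameter name back to its checkpoint name (illustrative only)."""
    # Hypothetical rename table: refactored prefix -> assumed checkpoint prefix.
    renames = {
        "model.embed_tokens.": "model.tok_embeddings.",
        "lm_head.": "output.",
    }
    for new_prefix, old_prefix in renames.items():
        if mlc_name.startswith(new_prefix):
            return old_prefix + mlc_name[len(new_prefix):]
    # Any parameter not covered by the table keeps its name unchanged.
    return mlc_name


if __name__ == "__main__":
    for name in ("model.embed_tokens.weight", "lm_head.weight"):
        print(name, "->", name_transform_sketch(name))
```

Keeping the renames in one table means the loader's remaining mapping logic can stay oblivious to which attributes were renamed in the model class.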

Changelog
  • python/mlc_llm/model/baichuan/baichuan_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated BaichuanForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
  • python/mlc_llm/model/chatglm3/chatglm3_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated ChatGLMForCausalLM to inherit from BaseForCausalLM.
    • Assigned self.transformer.output_layer to self.lm_head.
    • Removed duplicated methods and attributes, delegating them to the base class.
    • Added _get_backbone and _get_embed_module helper methods.
  • python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
    • Removed unused imports for te and op_ext.
    • Updated RopeMode import path.
    • Imported BaseForCausalLM.
    • Updated GPTBigCodeForCausalLM to inherit from BaseForCausalLM.
    • Adjusted model configuration attributes for consistency with the base class.
    • Removed duplicated methods and attributes, delegating them to the base class.
    • Added _embed_tokens, _get_backbone, and _get_embed_module helper methods.
  • python/mlc_llm/model/gpt_j/gpt_j_model.py
    • Removed unused imports for te and op_ext.
    • Imported BaseForCausalLM.
    • Updated GPTJForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
    • Added _embed_tokens, _get_backbone, and _get_embed_module helper methods.
  • python/mlc_llm/model/gpt_neox/gpt_neox_model.py
    • Removed unused imports for te and op_ext.
    • Imported BaseForCausalLM.
    • Updated GPTNeoXForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
    • Added rotary_dim, _embed_tokens, _get_backbone, _get_embed_module, and _get_lm_head helper methods.
  • python/mlc_llm/model/internlm/internlm_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated InternLMForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
  • python/mlc_llm/model/internlm2/internlm2_loader.py
    • Added a _name_transform function to map parameter names.
    • Modified parameter mapping logic to use the new _name_transform function.
  • python/mlc_llm/model/internlm2/internlm2_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Renamed self.tok_embeddings to self.embed_tokens in InternLM2Model.
    • Updated InternLM2ForCausalLM to inherit from BaseForCausalLM.
    • Renamed self.output to self.lm_head.
    • Removed duplicated methods and attributes, delegating them to the base class.
  • python/mlc_llm/model/orion/orion_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated OrionForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
  • python/mlc_llm/model/stable_lm/stablelm_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated StableLmForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
    • Added rotary_dim attribute.
  • python/mlc_llm/model/starcoder2/starcoder2_model.py
    • Removed unused imports for te, op_ext, and RopeMode.
    • Imported BaseForCausalLM.
    • Updated Starcoder2ForCausalLM to inherit from BaseForCausalLM.
    • Removed duplicated methods and attributes, delegating them to the base class.
  • python/mlc_llm/nn/__init__.py
    • Added imports for BaseForCausalLM and index_last_token.
  • python/mlc_llm/nn/causal_lm.py
    • Added new file causal_lm.py.
    • Defined index_last_token function for extracting the last token slice.
    • Implemented BaseForCausalLM class with common methods for causal language models, including to, embed, batch_forward, prefill, decode, batch_prefill, batch_decode, batch_verify, create_paged_kv_cache, and get_default_spec.
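The pattern the changelog describes (shared forward logic in a base class, with per-model helpers such as `_get_backbone` and `_get_embed_module` pointing at differently named submodules) might look roughly like the toy below. It is a pure-Python sketch under loud assumptions: nested lists stand in for tensors, the KV cache is omitted, and `ToyCausalLMBase`, `ToyModel`, and every method body are hypothetical rather than taken from causal_lm.py.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List

Vector = List[float]
Hidden = List[List[Vector]]  # [batch, seq_len, hidden_size]


def index_last_token(hidden: Hidden) -> Hidden:
    """Keep only the final position of each sequence: [b, s, d] -> [b, 1, d]."""
    return [[seq[-1]] for seq in hidden]


class ToyCausalLMBase(ABC):
    """Hypothetical stand-in for the shared base class."""

    def __init__(self) -> None:
        self.dtype = "float32"  # default so shared methods can always read it

    @abstractmethod
    def _get_embed_module(self) -> Dict[int, Vector]:
        """Return the model's embedding table, whatever it is called."""

    @abstractmethod
    def _get_backbone(self) -> Callable[[Hidden], Hidden]:
        """Return the model's transformer stack."""

    @abstractmethod
    def _get_lm_head(self) -> Callable[[Hidden], Hidden]:
        """Return the model's output projection."""

    def embed(self, token_ids: List[List[int]]) -> Hidden:
        table = self._get_embed_module()
        return [[table[t] for t in seq] for seq in token_ids]

    def prefill(self, input_embed: Hidden) -> Hidden:
        # Shared logic: run the backbone, keep the last token, project to logits.
        hidden = self._get_backbone()(input_embed)
        return self._get_lm_head()(index_last_token(hidden))


class ToyModel(ToyCausalLMBase):
    """A model only supplies its own submodules; the forward logic is inherited."""

    def __init__(self) -> None:
        super().__init__()
        self.table = {0: [0.0, 1.0], 1: [1.0, 0.0], 2: [2.0, 2.0]}

    def _get_embed_module(self) -> Dict[int, Vector]:
        return self.table

    def _get_backbone(self) -> Callable[[Hidden], Hidden]:
        return lambda hidden: hidden  # identity backbone for the toy

    def _get_lm_head(self) -> Callable[[Hidden], Hidden]:
        # Sum the features into a single "logit" per position.
        return lambda hidden: [[[sum(vec)] for vec in seq] for seq in hidden]


if __name__ == "__main__":
    model = ToyModel()
    print(model.prefill(model.embed([[0, 1, 2]])))
```

The hooks are the only part each model has to write, which is where the reduction in duplicated code would come from.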
Activity
  • The author, babusid, has marked this pull request as non-draft for Gemini review, indicating it's ready for initial feedback despite being an ongoing refactor.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable refactoring by creating a BaseForCausalLM class to unify the various Causal Language Model implementations. This change greatly reduces code duplication and improves maintainability. The common logic for forward passes, KV cache creation, and module specification is now centralized in the base class. The refactoring has been applied consistently across numerous models, with appropriate overrides for models that have unique naming conventions for their components. Overall, this is a solid architectural improvement. I have one suggestion to make the new base class even more robust.

babusid and others added 3 commits March 6, 2026 17:54
Extract only the truly shared methods (embed, prefill, decode,
batch_forward, etc.) into a base class. Model-specific logic
stays in each model file.
@babusid babusid force-pushed the causal_lm_refactor branch from f865f31 to 8ae622a March 8, 2026 17:53
@babusid
Contributor Author

babusid commented Mar 9, 2026

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request is a great refactoring effort to unify the CausalLM classes by introducing a CausalLMABC base class. This significantly reduces code duplication across various model implementations and improves maintainability. The changes are consistent and well-structured, and I appreciate the cleanups like removing duplicated attribute assignments.

I've found one critical issue regarding the initialization of the dtype attribute which would cause a runtime error. Please see the detailed comment.

dtype: str
hidden_size: int
tensor_parallel_shards: int

critical

The dtype attribute is used in create_paged_kv_cache and get_default_spec in subclasses, but it is not initialized in this base class. The original model implementations set self.dtype = "float32" in their __init__ methods, but this line was removed during refactoring and not added to the CausalLMABC. This will lead to an AttributeError if methods like get_default_spec are called before to(dtype=...) has been invoked.

Please initialize self.dtype in the base class __init__ to provide a default value.

Suggested change
def __init__(self):
    super().__init__()
    self.dtype = "float32"
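The failure mode described in this comment can be reproduced in isolation. The sketch below assumes the base class only annotates `dtype` at class level, as in the quoted diff context; the class names and the dictionary returned by `get_default_spec` are hypothetical simplifications, not the real module spec.

```python
class WithoutDefault:
    """Mimics a base class that annotates dtype but never assigns it."""

    dtype: str  # annotation only: this creates no attribute at runtime

    def to(self, dtype: str) -> None:
        self.dtype = dtype

    def get_default_spec(self) -> dict:
        # Reads self.dtype, so it fails unless something assigned it first.
        return {"dtype": self.dtype}


class WithDefault(WithoutDefault):
    """Applies the suggested fix: assign a default in __init__."""

    def __init__(self) -> None:
        super().__init__()
        self.dtype = "float32"


try:
    WithoutDefault().get_default_spec()
    raised = False
except AttributeError:
    raised = True

if __name__ == "__main__":
    print("missing default raises AttributeError:", raised)
    print("with default:", WithDefault().get_default_spec())
```

A bare class-level annotation is recorded in `__annotations__` but does not define the attribute, which is why the default has to be assigned explicitly.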
