
Commit a8fc81b

Add support for InternLM2 model architecture (#1958)
# Add InternLM2 Model Support

Adds full support for the InternLM2 model family (1.8B, 7B, etc.) to ONNX Runtime GenAI.

## Changes

### Core Implementation

- **New InternLM2Model builder** (`src/python/py/models/builders/internlm.py`)
  - Extends LlamaModel with InternLM2-specific weight mapping
  - GQA support: 16 query heads, 8 KV heads (2:1 ratio)
  - Proper grouped QKV weight splitting for the GroupQueryAttention operator
- **Model registration** (`builder.py`, `__init__.py`, `model_type.h`)
  - Maps `InternLM2ForCausalLM` → `InternLM2Model`
  - Adds "internlm2" to the supported model types

### Tokenizer Support

- **Upstream**: Contributed InternLM2Tokenizer support to [onnxruntime-extensions#1023](microsoft/onnxruntime-extensions#1023) (merged)
- **Dependencies**:
  - Updated `cmake/deps.txt` to onnxruntime-extensions commit `087953cd`
  - Removed the local patch in `cmake/external/onnxruntime_external_deps.cmake`
- **Fix**: Set a correct `model_max_length` in tokenizer_config.json (prevents invalid 1e30 values)

### Documentation

- Updated README.md and src/python/py/models/README.md

## Usage

Export:

```
python -m onnxruntime_genai.models.builder \
    --model_name internlm/internlm2-1_8b \
    --output ./internlm2-cpu-int4 \
    --precision int4 \
    --execution_provider cpu
```

Inference:

```
import onnxruntime_genai as og

model = og.Model("./internlm2-cpu-int4")
tokenizer = og.Tokenizer(model)
... standard generation code
```

## Testing

- ✅ InternLM2-1.8B INT4 CPU: export and inference
- ✅ InternLM2-7B INT4 CPU: export tested
- ✅ GQA weight splitting verified
- ✅ Tokenizer recognition working

## References

- Model: https://huggingface.co/internlm/internlm2-1_8b
- Upstream PR: microsoft/onnxruntime-extensions#1023

---------

Signed-off-by: Rajeev Patwari <rajeevp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
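
For reference, the "standard generation code" elided in the inference snippet above typically looks like the minimal sketch below. This is not part of the commit, and the exact generator API names can vary slightly across onnxruntime-genai releases.

```
import onnxruntime_genai as og

# Minimal token-streaming loop; API names follow recent onnxruntime-genai releases.
model = og.Model("./internlm2-cpu-int4")
tokenizer = og.Tokenizer(model)
tokenizer_stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("What is the capital of France?"))

while not generator.is_done():
    generator.generate_next_token()
    # Decode and print each new token as it is produced
    print(tokenizer_stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
print()
```
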
1 parent b262497 commit a8fc81b

File tree

- README.md
- cmake/deps.txt
- src/models/model_type.h
- src/python/py/models/README.md
- src/python/py/models/builder.py
- src/python/py/models/builders/__init__.py
- src/python/py/models/builders/base.py
- src/python/py/models/builders/internlm.py

8 files changed: +141 -4 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena
 
 | Support matrix | Supported now | Under development | On the roadmap|
 | -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Fara <br/> Gemma <br/> gpt-oss <br/> Granite <br/> Llama <br/> Mistral <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen (language + vision) <br/> SmolLM3 <br/> Whisper | Stable diffusion | Multi-modal models |
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Fara <br/> Gemma <br/> gpt-oss <br/> Granite <br/> InternLM2 <br/> Llama <br/> Mistral <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen (language + vision) <br/> SmolLM3 <br/> Whisper | Stable diffusion | Multi-modal models |
 | API| Python <br/>C# <br/>C/C++ <br/> Java ^ | Objective-C ||
 | O/S | Linux <br/> Windows <br/>Mac <br/>Android || iOS |||
 | Architecture | x86 <br/> x64 <br/> arm64 ||||

cmake/deps.txt

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;2fbe0ebbb3eb21199ab74c92b6edf3804d827998
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;087953cde6149e423c6848c40c3791264272706c
 
 # These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
 llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d

src/models/model_type.h

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+// Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved.
+// --------------------------------------------------------------------------
+
 #pragma once
 
 #include <algorithm>
@@ -12,7 +15,7 @@ namespace Generators {
 struct ModelType {
   inline static bool IsLLM(const std::string& model_type) {
     // Large-language model (LLM)
-    static constexpr std::array<std::string_view, 20> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
+    static constexpr std::array<std::string_view, 21> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "internlm2", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
     return std::find(LLM.begin(), LLM.end(), model_type) != LLM.end();
   }

src/python/py/models/README.md

Lines changed: 2 additions & 0 deletions
@@ -41,13 +41,15 @@ The tool currently supports the following model architectures.
 - Gemma
 - gpt-oss
 - Granite
+- InternLM2
 - Llama
 - Mistral
 - Nemotron
 - Phi
 - Qwen
 - SmolLM3
 
+
 It is intended for supporting the latest, popular state-of-the-art models.
 
 ## Usage

src/python/py/models/builder.py

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,8 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
+# --------------------------------------------------------------------------
 """
 Run the model builder to create the desired ONNX model.
 """
@@ -22,6 +24,7 @@
     GemmaModel,
     GPTOSSModel,
     GraniteModel,
+    InternLM2Model,
     LlamaModel,
     MistralModel,
     Model,
@@ -244,6 +247,8 @@ def create_model(
         onnx_model = GPTOSSModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "GraniteForCausalLM":
         onnx_model = GraniteModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "InternLM2ForCausalLM":
+        onnx_model = InternLM2Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "LlamaForCausalLM":
         onnx_model = LlamaModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "MistralForCausalLM":

src/python/py/models/builders/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,12 +3,15 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
+# --------------------------------------------------------------------------
 from .base import Model
 from .chatglm import ChatGLMModel
 from .ernie import ErnieModel
 from .gemma import Gemma2Model, Gemma3Model, GemmaModel
 from .gptoss import GPTOSSModel
 from .granite import GraniteModel
+from .internlm import InternLM2Model
 from .llama import LlamaModel
 from .mistral import MistralModel
 from .nemotron import NemotronModel
@@ -34,6 +37,7 @@
     "Gemma3Model",
     "GemmaModel",
     "GraniteModel",
+    "InternLM2Model",
     "LlamaModel",
     "MistralModel",
     "Model",

src/python/py/models/builders/base.py

Lines changed: 4 additions & 1 deletion
@@ -3,7 +3,7 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 #
-# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
 # --------------------------------------------------------------------------
 from __future__ import annotations
 
@@ -671,6 +671,9 @@ def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
         tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path, token=self.hf_token, trust_remote_code=self.hf_remote, **extra_kwargs
         )
+        # Overwrite model_max_length with the model's context_length so it is a normal integer
+        # (HF often uses 1e30 for "no limit", which can serialize to a huge decimal in JSON)
+        tokenizer.model_max_length = self.context_length
         print(f"Saving processing files in {out_dir} for GenAI")
         tokenizer.save_pretrained(out_dir)
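
Why the `model_max_length` overwrite matters, as a small illustrative sketch (the 32768 context length below is a hypothetical value, not taken from this commit):

```
import json

# Hugging Face tokenizers fall back to int(1e30) as a "no limit" sentinel for
# model_max_length; because 1e30 is a float, the sentinel is a huge, untidy integer.
print(int(1e30))  # a ~30-digit integer, not a clean power of ten

# The fix overwrites the sentinel with the model's real context length before
# save_pretrained(), so tokenizer_config.json keeps a small, valid integer.
context_length = 32768  # hypothetical context length read from the model config
print(json.dumps({"model_max_length": context_length}))  # {"model_max_length": 32768}
```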

src/python/py/models/builders/internlm.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# -------------------------------------------------------------------------
# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .base import Model
import torch.nn as nn


class InternLM2Model(Model):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        # Export genai_config with type "internlm2" (C++ model_type.h already lists "internlm2" as LLM)
        self.model_type = "InternLM2ForCausalLM"

    def load_weights(self, input_path):
        """
        Load the InternLM2 model and adapt attribute names to match base class expectations.
        InternLM2 uses:
        - attention_norm instead of input_layernorm
        - ffn_norm instead of post_attention_layernorm
        - feed_forward instead of mlp
        - wqkv (combined QKV) instead of separate q_proj, k_proj, v_proj
        - wo instead of o_proj
        """
        # Load the model using the parent class method
        model = super().load_weights(input_path)

        # Get config from the loaded model
        config = model.config

        # Adapt each decoder layer to match the expected attribute names
        for layer in model.model.layers:
            # Map attention_norm to input_layernorm
            if hasattr(layer, 'attention_norm') and not hasattr(layer, 'input_layernorm'):
                layer.input_layernorm = layer.attention_norm

            # Map ffn_norm to post_attention_layernorm
            if hasattr(layer, 'ffn_norm') and not hasattr(layer, 'post_attention_layernorm'):
                layer.post_attention_layernorm = layer.ffn_norm

            # Map feed_forward to mlp
            if hasattr(layer, 'feed_forward') and not hasattr(layer, 'mlp'):
                layer.mlp = layer.feed_forward

            # Map attention to self_attn
            if hasattr(layer, 'attention') and not hasattr(layer, 'self_attn'):
                layer.self_attn = layer.attention

            # Map MLP projections (w1/w2/w3 to gate_proj/down_proj/up_proj)
            if hasattr(layer.mlp, 'w1') and not hasattr(layer.mlp, 'gate_proj'):
                layer.mlp.gate_proj = layer.mlp.w1
            if hasattr(layer.mlp, 'w2') and not hasattr(layer.mlp, 'down_proj'):
                layer.mlp.down_proj = layer.mlp.w2
            if hasattr(layer.mlp, 'w3') and not hasattr(layer.mlp, 'up_proj'):
                layer.mlp.up_proj = layer.mlp.w3

            # Handle the combined wqkv projection in attention
            # InternLM2 uses a grouped/interleaved layout: [Q1, Q2, ..., Qn, K, V] per KV group
            # Layout: [batch, seq, num_kv_heads, (num_q_heads_per_kv_group + 2), head_dim]
            if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'wqkv'):
                attn = layer.self_attn
                wqkv_weight = attn.wqkv.weight  # Shape: [(num_heads + 2*num_kv_heads) * head_dim, hidden_size]
                wqkv_bias = attn.wqkv.bias if hasattr(attn.wqkv, 'bias') and attn.wqkv.bias is not None else None

                # Calculate dimensions
                num_q_heads = config.num_attention_heads
                num_kv_heads = config.num_key_value_heads
                num_kv_groups = num_q_heads // num_kv_heads  # How many Q heads per KV head
                head_dim = config.hidden_size // num_q_heads

                q_size = num_q_heads * head_dim
                kv_size = num_kv_heads * head_dim

                # InternLM2's wqkv is organized as interleaved groups:
                # For each KV head group: [Q_heads for this group (num_kv_groups heads), K for this group, V for this group]
                # We need to reshape and reorder to standard [all Q | all K | all V] layout

                # Reshape to grouped format: [num_kv_heads, (num_kv_groups + 2), head_dim, hidden_size]
                group_size = num_kv_groups + 2
                wqkv_grouped = wqkv_weight.reshape(num_kv_heads, group_size, head_dim, config.hidden_size)

                # Extract Q, K, V from grouped layout
                # Q heads: first num_kv_groups entries in each group
                q_weight = wqkv_grouped[:, :num_kv_groups, :, :].reshape(num_q_heads, head_dim, config.hidden_size)
                q_weight = q_weight.reshape(q_size, config.hidden_size)

                # K heads: second to last entry in each group
                k_weight = wqkv_grouped[:, -2, :, :].reshape(kv_size, config.hidden_size)

                # V heads: last entry in each group
                v_weight = wqkv_grouped[:, -1, :, :].reshape(kv_size, config.hidden_size)

                # Create separate projection layers
                attn.q_proj = nn.Linear(config.hidden_size, q_size, bias=config.bias)
                attn.k_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias)
                attn.v_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias)

                # Copy weights (ensure proper copy and contiguous memory)
                attn.q_proj.weight.data.copy_(q_weight.contiguous())
                attn.k_proj.weight.data.copy_(k_weight.contiguous())
                attn.v_proj.weight.data.copy_(v_weight.contiguous())

                # Handle biases if they exist (same grouped layout)
                if wqkv_bias is not None:
                    bias_grouped = wqkv_bias.reshape(num_kv_heads, group_size, head_dim)

                    q_bias = bias_grouped[:, :num_kv_groups, :].reshape(q_size)
                    k_bias = bias_grouped[:, -2, :].reshape(kv_size)
                    v_bias = bias_grouped[:, -1, :].reshape(kv_size)

                    attn.q_proj.bias.data.copy_(q_bias.contiguous())
                    attn.k_proj.bias.data.copy_(k_bias.contiguous())
                    attn.v_proj.bias.data.copy_(v_bias.contiguous())

                # Map wo to o_proj
                if hasattr(attn, 'wo') and not hasattr(attn, 'o_proj'):
                    attn.o_proj = attn.wo

        return model
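
To make the grouped wqkv layout concrete, here is a small sanity check (not part of the commit) that traces which rows of a toy wqkv weight end up in Q, K, and V; it mirrors the reshape/slice logic in `load_weights` above:

```
import torch

# Toy dimensions: 4 query heads, 2 KV heads (2 query heads per KV group), head_dim 2.
num_q_heads, num_kv_heads, head_dim = 4, 2, 2
hidden_size = num_q_heads * head_dim
num_kv_groups = num_q_heads // num_kv_heads
group_size = num_kv_groups + 2

# Fill each wqkv output row with its own row index so we can trace where it lands.
rows = (num_q_heads + 2 * num_kv_heads) * head_dim
wqkv = torch.arange(rows).unsqueeze(1).repeat(1, hidden_size)

grouped = wqkv.reshape(num_kv_heads, group_size, head_dim, hidden_size)
q = grouped[:, :num_kv_groups].reshape(num_q_heads * head_dim, hidden_size)
k = grouped[:, -2].reshape(num_kv_heads * head_dim, hidden_size)
v = grouped[:, -1].reshape(num_kv_heads * head_dim, hidden_size)

# Per KV group the row layout is [Q, Q, K, V], each entry head_dim rows wide:
print(q[:, 0].tolist())  # [0, 1, 2, 3, 8, 9, 10, 11] -> Q rows from both groups
print(k[:, 0].tolist())  # [4, 5, 12, 13]             -> K block of each group
print(v[:, 0].tolist())  # [6, 7, 14, 15]             -> V block of each group
```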

0 commit comments
