2 changes: 1 addition & 1 deletion README.md
@@ -16,7 +16,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena

| Support matrix | Supported now | Under development | On the roadmap|
| -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Fara <br/> Gemma <br/> gpt-oss <br/> Granite <br/> Llama <br/> Mistral <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen (language + vision) <br/> SmolLM3 <br/> Whisper | Stable diffusion | Multi-modal models |
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Fara <br/> Gemma <br/> gpt-oss <br/> Granite <br/> InternLM2 <br/> Llama <br/> Mistral <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen (language + vision) <br/> SmolLM3 <br/> Whisper | Stable diffusion | Multi-modal models |
| API| Python <br/>C# <br/>C/C++ <br/> Java ^ | Objective-C ||
| O/S | Linux <br/> Windows <br/>Mac <br/>Android || iOS |||
| Architecture | x86 <br/> x64 <br/> arm64 ||||
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;2fbe0ebbb3eb21199ab74c92b6edf3804d827998
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;087953cde6149e423c6848c40c3791264272706c

# These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d
5 changes: 4 additions & 1 deletion src/models/model_type.h
@@ -1,5 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+// Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved.
+// --------------------------------------------------------------------------

#pragma once

#include <algorithm>
@@ -12,7 +15,7 @@ namespace Generators {
struct ModelType {
  inline static bool IsLLM(const std::string& model_type) {
    // Large-language model (LLM)
-    static constexpr std::array<std::string_view, 20> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
+    static constexpr std::array<std::string_view, 21> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "internlm2", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
    return std::find(LLM.begin(), LLM.end(), model_type) != LLM.end();
  }

2 changes: 2 additions & 0 deletions src/python/py/models/README.md
@@ -41,13 +41,15 @@ The tool currently supports the following model architectures.
- Gemma
- gpt-oss
- Granite
- InternLM2
- Llama
- Mistral
- Nemotron
- Phi
- Qwen
- SmolLM3


It is intended for supporting the latest, popular state-of-the-art models.

## Usage
5 changes: 5 additions & 0 deletions src/python/py/models/builder.py
@@ -3,6 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
+# --------------------------------------------------------------------------
"""
Run the model builder to create the desired ONNX model.
"""
@@ -22,6 +24,7 @@
    GemmaModel,
    GPTOSSModel,
    GraniteModel,
+    InternLM2Model,
    LlamaModel,
    MistralModel,
    Model,
@@ -244,6 +247,8 @@ def create_model(
        onnx_model = GPTOSSModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
    elif config.architectures[0] == "GraniteForCausalLM":
        onnx_model = GraniteModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "InternLM2ForCausalLM":
+        onnx_model = InternLM2Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
    elif config.architectures[0] == "LlamaForCausalLM":
        onnx_model = LlamaModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options)
    elif config.architectures[0] == "MistralForCausalLM":
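With the new dispatch branch in place, the existing model builder CLI can target InternLM2 checkpoints. As a hypothetical example (the model id and paths are placeholders, and the flags follow the builder's documented `-m`/`-o`/`-p`/`-e`/`-c` options): `python3 builder.py -m internlm/internlm2-chat-1_8b -o ./internlm2-onnx -p int4 -e cpu -c ./hf_cache`.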
4 changes: 4 additions & 0 deletions src/python/py/models/builders/__init__.py
@@ -3,12 +3,15 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
+# --------------------------------------------------------------------------
from .base import Model
from .chatglm import ChatGLMModel
from .ernie import ErnieModel
from .gemma import Gemma2Model, Gemma3Model, GemmaModel
from .gptoss import GPTOSSModel
from .granite import GraniteModel
from .internlm import InternLM2Model
from .llama import LlamaModel
from .mistral import MistralModel
from .nemotron import NemotronModel
@@ -34,6 +37,7 @@
    "Gemma3Model",
    "GemmaModel",
    "GraniteModel",
+    "InternLM2Model",
    "LlamaModel",
    "MistralModel",
    "Model",
11 changes: 10 additions & 1 deletion src/python/py/models/builders/base.py
@@ -3,7 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#
-# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content.
# --------------------------------------------------------------------------
from __future__ import annotations

@@ -673,6 +673,15 @@ def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
        )
        print(f"Saving processing files in {out_dir} for GenAI")
        tokenizer.save_pretrained(out_dir)
+        # Overwrite model_max_length with the model's context_length so it is a normal integer
+        # (HF often uses 1e30 for "no limit", which can serialize to a huge decimal in JSON)
+        tokenizer_config_path = os.path.join(out_dir, "tokenizer_config.json")
+        if os.path.isfile(tokenizer_config_path):
+            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+            config["model_max_length"] = self.context_length
+            with open(tokenizer_config_path, "w", encoding="utf-8") as f:
+                json.dump(config, f, indent=2, ensure_ascii=False)

    def make_int4_algo_config(self, quant_method: str):
        customized_weight_config = {}
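For context on the tokenizer_config.json overwrite above, here is a minimal sketch (not part of the PR) of why the default value is undesirable: Hugging Face tokenizers with no explicit limit store `int(1e30)` in `model_max_length`, which round-trips through JSON as a 31-digit integer. The replacement value of 32768 below is purely illustrative of a model's real `context_length`.

```python
import json

# HF's "no limit" sentinel serializes as a huge decimal literal
# rather than a meaningful context length.
hf_default = int(1e30)
print(json.dumps({"model_max_length": hf_default}))
# -> {"model_max_length": 1000000000000000019884624838656}

# After save_processing, the field holds the model's context length
# (32768 is a hypothetical value used only for illustration).
print(json.dumps({"model_max_length": 32768}))
# -> {"model_max_length": 32768}
```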
120 changes: 120 additions & 0 deletions src/python/py/models/builders/internlm.py
@@ -0,0 +1,120 @@
# -------------------------------------------------------------------------
# Copyright (C) [2026] Advanced Micro Devices, Inc. All rights reserved. Portions of this file consist of AI generated content
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .base import Model
import torch.nn as nn


class InternLM2Model(Model):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        # Export genai_config with type "internlm2" (C++ model_type.h already lists "internlm2" as LLM)
        self.model_type = "InternLM2ForCausalLM"

    def load_weights(self, input_path):
        """
        Load the InternLM2 model and adapt attribute names to match base class expectations.
        InternLM2 uses:
        - attention_norm instead of input_layernorm
        - ffn_norm instead of post_attention_layernorm
        - feed_forward instead of mlp
        - wqkv (combined QKV) instead of separate q_proj, k_proj, v_proj
        - wo instead of o_proj
        """
        # Load the model using the parent class method
        model = super().load_weights(input_path)

        # Get config from the loaded model
        config = model.config

        # Adapt each decoder layer to match the expected attribute names
        for layer in model.model.layers:
            # Map attention_norm to input_layernorm
            if hasattr(layer, 'attention_norm') and not hasattr(layer, 'input_layernorm'):
                layer.input_layernorm = layer.attention_norm

            # Map ffn_norm to post_attention_layernorm
            if hasattr(layer, 'ffn_norm') and not hasattr(layer, 'post_attention_layernorm'):
                layer.post_attention_layernorm = layer.ffn_norm

            # Map feed_forward to mlp
            if hasattr(layer, 'feed_forward') and not hasattr(layer, 'mlp'):
                layer.mlp = layer.feed_forward

            # Map attention to self_attn
            if hasattr(layer, 'attention') and not hasattr(layer, 'self_attn'):
                layer.self_attn = layer.attention

            # Map MLP projections (w1/w2/w3 to gate_proj/down_proj/up_proj)
            if hasattr(layer.mlp, 'w1') and not hasattr(layer.mlp, 'gate_proj'):
                layer.mlp.gate_proj = layer.mlp.w1
            if hasattr(layer.mlp, 'w2') and not hasattr(layer.mlp, 'down_proj'):
                layer.mlp.down_proj = layer.mlp.w2
            if hasattr(layer.mlp, 'w3') and not hasattr(layer.mlp, 'up_proj'):
                layer.mlp.up_proj = layer.mlp.w3

            # Handle the combined wqkv projection in attention
            # InternLM2 uses a grouped/interleaved layout: [Q1, Q2, ..., Qn, K, V] per KV group
            # Layout: [batch, seq, num_kv_heads, (num_q_heads_per_kv_group + 2), head_dim]
            if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'wqkv'):
                attn = layer.self_attn
                wqkv_weight = attn.wqkv.weight # Shape: [(num_heads + 2*num_kv_heads) * head_dim, hidden_size]
                wqkv_bias = attn.wqkv.bias if hasattr(attn.wqkv, 'bias') and attn.wqkv.bias is not None else None

                # Calculate dimensions
                num_q_heads = config.num_attention_heads
                num_kv_heads = config.num_key_value_heads
                num_kv_groups = num_q_heads // num_kv_heads # How many Q heads per KV head
                head_dim = config.hidden_size // num_q_heads

                q_size = num_q_heads * head_dim
                kv_size = num_kv_heads * head_dim

                # InternLM2's wqkv is organized as interleaved groups:
                # For each KV head group: [Q_heads for this group (num_kv_groups heads), K for this group, V for this group]
                # We need to reshape and reorder to standard [all Q | all K | all V] layout

                # Reshape to grouped format: [num_kv_heads, (num_kv_groups + 2), head_dim, hidden_size]
                group_size = num_kv_groups + 2
                wqkv_grouped = wqkv_weight.reshape(num_kv_heads, group_size, head_dim, config.hidden_size)

                # Extract Q, K, V from grouped layout
                # Q heads: first num_kv_groups entries in each group
                q_weight = wqkv_grouped[:, :num_kv_groups, :, :].reshape(num_q_heads, head_dim, config.hidden_size)
                q_weight = q_weight.reshape(q_size, config.hidden_size)

                # K heads: second to last entry in each group
                k_weight = wqkv_grouped[:, -2, :, :].reshape(kv_size, config.hidden_size)

                # V heads: last entry in each group
                v_weight = wqkv_grouped[:, -1, :, :].reshape(kv_size, config.hidden_size)

                # Create separate projection layers
                attn.q_proj = nn.Linear(config.hidden_size, q_size, bias=config.bias)
                attn.k_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias)
                attn.v_proj = nn.Linear(config.hidden_size, kv_size, bias=config.bias)

                # Copy weights (ensure proper copy and contiguous memory)
                attn.q_proj.weight.data.copy_(q_weight.contiguous())
                attn.k_proj.weight.data.copy_(k_weight.contiguous())
                attn.v_proj.weight.data.copy_(v_weight.contiguous())

                # Handle biases if they exist (same grouped layout)
                if wqkv_bias is not None:
                    bias_grouped = wqkv_bias.reshape(num_kv_heads, group_size, head_dim)

                    q_bias = bias_grouped[:, :num_kv_groups, :].reshape(q_size)
                    k_bias = bias_grouped[:, -2, :].reshape(kv_size)
                    v_bias = bias_grouped[:, -1, :].reshape(kv_size)

                    attn.q_proj.bias.data.copy_(q_bias.contiguous())
                    attn.k_proj.bias.data.copy_(k_bias.contiguous())
                    attn.v_proj.bias.data.copy_(v_bias.contiguous())

                # Map wo to o_proj
                if hasattr(attn, 'wo') and not hasattr(attn, 'o_proj'):
                    attn.o_proj = attn.wo

        return model
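To sanity-check the interleaved-to-split conversion above, the following standalone sketch (not part of the PR; all dimensions are made up for illustration) builds a random fused wqkv weight, splits it the same way load_weights does, and confirms the separate projections reproduce the fused projection's Q/K/V slices.

```python
import torch

# Hypothetical sizes chosen only for illustration
hidden_size, num_q_heads, num_kv_heads = 64, 8, 2
head_dim = hidden_size // num_q_heads        # 8
num_kv_groups = num_q_heads // num_kv_heads  # 4 Q heads per KV head
group_size = num_kv_groups + 2               # Q heads + K + V per group

# Fused InternLM2-style weight: [(num_q_heads + 2*num_kv_heads) * head_dim, hidden_size]
wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)

# Split exactly as in InternLM2Model.load_weights
grouped = wqkv.reshape(num_kv_heads, group_size, head_dim, hidden_size)
q_w = grouped[:, :num_kv_groups].reshape(num_q_heads * head_dim, hidden_size)
k_w = grouped[:, -2].reshape(num_kv_heads * head_dim, hidden_size)
v_w = grouped[:, -1].reshape(num_kv_heads * head_dim, hidden_size)

# Reference: run the fused projection and slice its output per KV group,
# the way the original InternLM2 attention does.
x = torch.randn(3, hidden_size)  # 3 dummy token embeddings
fused = (x @ wqkv.T).reshape(3, num_kv_heads, group_size, head_dim)
q_ref = fused[:, :, :num_kv_groups].reshape(3, num_q_heads * head_dim)
k_ref = fused[:, :, -2].reshape(3, num_kv_heads * head_dim)
v_ref = fused[:, :, -1].reshape(3, num_kv_heads * head_dim)

assert torch.allclose(x @ q_w.T, q_ref, atol=1e-5)
assert torch.allclose(x @ k_w.T, k_ref, atol=1e-5)
assert torch.allclose(x @ v_w.T, v_ref, atol=1e-5)
print("split q/k/v projections match the fused wqkv projection")
```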