Add support for InternLM2 model architecture #1958

Merged
kunal-vaishnavi merged 13 commits into microsoft:main from amdrajeevp1:add-internlm2-support on Feb 11, 2026


Conversation

@amdrajeevp1
Contributor

amdrajeevp1 commented Jan 30, 2026

Add InternLM2 Model Support

Adds full support for the InternLM2 model family (1.8B, 7B, etc.) to ONNX Runtime GenAI.

Changes

Core Implementation

  • New InternLM2Model builder (src/python/py/models/builders/internlm.py)
    • Extends LlamaModel with InternLM2-specific weight mapping
    • GQA support: 16 query heads, 8 KV heads (2:1 ratio)
    • Proper grouped QKV weight splitting for the GroupQueryAttention operator (see the sketch after this list)
  • Model registration (builder.py, __init__.py, model_type.h)
    • Maps InternLM2ForCausalLM → InternLM2Model
    • Adds "internlm2" to supported model types
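
A minimal sketch of that grouped split, assuming the fused wqkv layout in InternLM2 checkpoints (rows grouped per KV head as that group's query heads, then one K head, then one V head); the function and names below are illustrative, not the builder's actual code:

```
import torch

def split_wqkv(wqkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int):
    """Split InternLM2's fused wqkv weight into separate Q, K, V projection weights.

    Assumes rows are grouped per KV head as [q_0 .. q_{g-1}, k, v], where
    g = num_heads // num_kv_heads (2 for InternLM2-1.8B: 16 Q heads, 8 KV heads).
    """
    group = num_heads // num_kv_heads
    hidden = wqkv.shape[-1]
    w = wqkv.view(num_kv_heads, group + 2, head_dim, hidden)
    q = w[:, :group].reshape(num_heads * head_dim, hidden)        # q_proj weight
    k = w[:, group].reshape(num_kv_heads * head_dim, hidden)      # k_proj weight
    v = w[:, group + 1].reshape(num_kv_heads * head_dim, hidden)  # v_proj weight
    return q, k, v
```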

Tokenizer Support

  • Upstream: Contributed InternLM2Tokenizer support to onnxruntime-extensions#1023 (merged)
  • Dependencies:
    • Updated cmake/deps.txt to onnxruntime-extensions commit 087953cd
    • Removed local patch in cmake/external/onnxruntime_external_deps.cmake
  • Fix: Set the correct model_max_length in tokenizer_config.json (prevents the invalid 1e30 placeholder value); a minimal sketch follows below
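
A sketch of that fix, using the approach settled on later in this PR (set the value on the tokenizer before save_pretrained instead of patching the JSON afterwards); the concrete context length here is illustrative:

```
from transformers import AutoTokenizer

context_length = 32768  # illustrative; the builder takes this from the model config

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-1_8b", trust_remote_code=True)
if tokenizer.model_max_length > context_length:  # transformers falls back to ~1e30 when unset
    tokenizer.model_max_length = context_length
tokenizer.save_pretrained("./internlm2-cpu-int4")  # tokenizer_config.json now gets a sane value
```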

Documentation

  • Updated README.md and src/python/py/models/README.md

Usage

Export

```
python -m onnxruntime_genai.models.builder \
  --model_name internlm/internlm2-1_8b \
  --output ./internlm2-cpu-int4 \
  --precision int4 \
  --execution_provider cpu
```

Inference

```
import onnxruntime_genai as og
model = og.Model("./internlm2-cpu-int4")
tokenizer = og.Tokenizer(model)
# ... standard generation code
```
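
The elided "standard generation code" could look like the loop below, following the usual onnxruntime-genai Python examples (API names such as GeneratorParams, append_tokens, and generate_next_token match recent releases and may differ in older ones):

```
import onnxruntime_genai as og

model = og.Model("./internlm2-cpu-int4")
tokenizer = og.Tokenizer(model)
stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("Write a short poem about compilers."))
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
print()
```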

Testing

  • ✅ InternLM2-1.8B INT4 CPU: export and inference
  • ✅ InternLM2-7B INT4 CPU: export tested
  • ✅ GQA weight splitting verified
  • ✅ Tokenizer recognition working

References

  • Model: https://huggingface.co/internlm/internlm2-1_8b
  • Upstream PR: microsoft/onnxruntime-extensions#1023

This commit adds support for exporting InternLM2 models to ONNX format.

Key changes:
- Add InternLM2Model class in src/python/py/models/builders/internlm.py
- Register InternLM2ForCausalLM architecture in builder.py
- Implement grouped/interleaved QKV weight splitting for GQA
- Map InternLM2-specific attribute names to base model equivalents
- Add documentation and example in examples/python/internlm2/

InternLM2 uses a Llama-based architecture with grouped query attention
and a unique grouped/interleaved QKV weight layout. The implementation
correctly handles this layout during weight extraction.

Tested with:
- InternLM2-1.8B (FP32, INT4 RTN, INT4 AWQ)
- Model generates coherent text and valid code

Model hub: https://huggingface.co/internlm/internlm2-1_8b
Paper: https://arxiv.org/abs/2403.17297
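
The "attribute names to base model equivalents" mapping mentioned in the commit message above could look roughly like this; the left-hand names follow the Hugging Face InternLM2 checkpoint layout, and the exact table in internlm.py may differ:

```
# Illustrative: InternLM2 checkpoint names -> Llama-style names expected by the
# base LlamaModel builder (the fused attention.wqkv weight is split separately).
INTERNLM2_TO_LLAMA = {
    "model.tok_embeddings": "model.embed_tokens",
    "attention.wo":         "self_attn.o_proj",
    "attention_norm":       "input_layernorm",
    "ffn_norm":             "post_attention_layernorm",
    "feed_forward.w1":      "mlp.gate_proj",
    "feed_forward.w3":      "mlp.up_proj",
    "feed_forward.w2":      "mlp.down_proj",
    "output":               "lm_head",
}
```
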
@kunal-vaishnavi
Contributor

Thanks for your contribution! Can you also make the following additions for InternLM in alphabetical order?

  • Add the model type that gets generated in the genai_config.json here:
    static constexpr std::array<std::string_view, 20> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
  • Add the model type to the model builder README
  • Add the model type to the repo's main README
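
With "internlm2" inserted in alphabetical order (and the array size bumped from 20 to 21), the line quoted in the first bullet above would become something like:

```
static constexpr std::array<std::string_view, 21> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "internlm2", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
```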

- Add comprehensive MULTI_SIZE_SUPPORT.md documenting 1.8B/7B/20B compatibility
- Add export scripts for InternLM2-7B (Bash and PowerShell)
- Update README with model size comparison table
- Add hardware requirements and performance estimates
- Include GPU export examples for 7B model

The implementation is architecture-based and works with all InternLM2 sizes:
- Dynamically reads config parameters (heads, layers, dimensions)
- Adaptive weight splitting based on GQA ratios
- No hardcoded model sizes

Tested: InternLM2-1.8B
Compatible: InternLM2-7B, InternLM2-20B, all Chat variants
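
The "no hardcoded model sizes" point above amounts to deriving every split size from the source checkpoint's config; a minimal illustration (standard Hugging Face config.json field names, with the 1.8B values quoted earlier used as examples):

```
import json

with open("config.json") as f:             # config of the source HF checkpoint
    cfg = json.load(f)

num_heads = cfg["num_attention_heads"]     # e.g. 16 for InternLM2-1.8B
num_kv_heads = cfg["num_key_value_heads"]  # e.g. 8 for InternLM2-1.8B
head_dim = cfg["hidden_size"] // num_heads
group_size = num_heads // num_kv_heads     # query heads per KV head (2 above)
# split_wqkv(...) from the earlier sketch then works unchanged for 7B and 20B.
```
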
- Merge MULTI_SIZE_SUPPORT.md into README.md for single comprehensive guide
- Remove export_7b.ps1 and export_7b.sh scripts (examples already in README)
- Streamline documentation structure
- All export commands and multi-size information now in one place

- Add AMD copyright to builder.py (after Microsoft copyright)
- Add AMD copyright to builders/__init__.py (after Microsoft copyright)
- Update internlm.py with AMD copyright
- Add AMD copyright to README.md

https://huggingface.co/onnx-community/InternLM2-ONNX/

Signed-off-by: Rajeev Patwari <rajeevp@amd.com>
@amdrajeevp1
Contributor Author

> Thanks for your contribution! Can you also make the following additions for InternLM in alphabetical order?
>
> • Add the model type that gets generated in the genai_config.json here:
>   static constexpr std::array<std::string_view, 20> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
> • Add the model type to the model builder README
> • Add the model type to the repo's main README

Hi @kunal-vaishnavi - done!
Please review the PR. I have also uploaded the generated artifacts: https://huggingface.co/onnx-community/InternLM2-ONNX/discussions/1

@amdrajeevp1
Contributor Author

@microsoft-github-policy-service agree [company="AMD"]

@amdrajeevp1
Contributor Author

@microsoft-github-policy-service agree company="AMD"

- Python builder: InternLM2Model in builders/internlm.py with HF->base
  name mapping and grouped wqkv split for GQA; export type 'internlm2'
- builder.py: register InternLM2ForCausalLM -> InternLM2Model
- builders/__init__.py: export InternLM2Model
- C++ model_type.h: add 'internlm2' to LLM list
- cmake: patch onnxruntime-extensions for InternLM2Tokenizer (BPE)
- base.py: set tokenizer_config model_max_length to context_length
- READMEs: list InternLM2 in supported models

Co-authored-by: Cursor <cursoragent@cursor.com>

- Bump extensions commit to 087953cd (includes InternLM2Tokenizer support)
- Remove local patch now that support is upstream

Ref: microsoft/onnxruntime-extensions#1023
Set tokenizer.model_max_length directly on the tokenizer object before calling save_pretrained(), eliminating the need to reopen and modify tokenizer_config.json after saving.

Co-authored-by: Cursor <cursoragent@cursor.com>
kunal-vaishnavi enabled auto-merge (squash) February 11, 2026 04:10
Remove trailing space from AMD copyright line in model_type.h

Co-authored-by: Cursor <cursoragent@cursor.com>
auto-merge was automatically disabled February 11, 2026 04:22

Head branch was pushed to by a user without write access

kunal-vaishnavi enabled auto-merge (squash) February 11, 2026 04:43
kunal-vaishnavi merged commit a8fc81b into microsoft:main on Feb 11, 2026
15 of 18 checks passed
baijumeswani pushed a commit that referenced this pull request Feb 12, 2026