Commit 471958b

Add GraniteMoeHybrid support for 4.0 (#37658)

* initial config and MLA layer
* first pass at decoder
* completion of layers
* modeling class
* adding hybrid class to imports
* fix imports granitemoehybrid
* fix granitehybrid imports
* fix granitehybrid import
* fix generated modeling file
* add some comments
* minor fixes in layers
* add sharedMLP layer
* correct layer names
* fixes in mamba config
* fix mamba config
* change name of MLP layer
* fix seq mizer layers
* correct mamba config
* fixes in param names
* enable hybrid model
* update config
* fix config granite hybrid
* fix attention layer
* cleanup to re-use mamba code
* keep layer types
* attention bias cleanup
* update mamba layer name
* first pass at tests
* first pass at tests
* use granite attention
* fix: self attn weights
* pass at making pos_emb optional
* initialize self_attn only as needed
* overwrite forward to create HybridMambaCache
* Log invalid layer types
* Add attention outputs test
* Only emit attentions/logits if not None
* Fix config test hidden size divisibility
* mark granitmoehybrid as stateful
* Initialize mamba convolutional layers
* Formatting fixes
* config docstring, removed some unused attrs
* Fix missing arg in models test
* Fix create and check decoder model test
* support logits to keep in granitemoe
* regen to pass logits_to_keep
* Allow None or rope
* Fix gradient checkpointing
* Add granitemoehybrid as special cache for generate check
* Remove unused MLA refs
* Fix mamba layer mask
* Remove logits to keep from config
* Minor docstring nits
* Update licenses
* Enable cache by default
* map layer types to layer block type
* First pass at granite moe hybrid docs
* Ignore granite moe hybrid in valid checkpoint check
* Align attention interfaces
* regenerate modular granitemoeshared attention interface
* Align granite moe hybrid attn interface
* run formatting
* Handle mamba initialization
* avoid conditional attr defs
* Move hybrid layer validation to config
* Add placeholder integration tests
* Docs nits / Update model names
* Clean up forward conditions
* Use gradient checkpointing layer
* Remove some copied bamba tests + inherit align test init, delete more tests, use common layer init with bamba tests, finish test consolidation
* avoid redundant intermediate std var
* use @can_return_tuple
* Remove unused moe state
* make skipped test names consistent
* Fix docstring order
* Add missing toc
* Always create the shared mlp
* Fix name in docstring
* link preview model in docs

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Co-authored-by: Alex-Brooks <[email protected]>
1 parent fe29b8c · commit 471958b

21 files changed: +3,150 −544 lines

docs/source/en/_toctree.yml (+2)

```diff
@@ -495,6 +495,8 @@
         title: Granite
       - local: model_doc/granitemoe
         title: GraniteMoe
+      - local: model_doc/granitemoehybrid
+        title: GraniteMoeHybrid
       - local: model_doc/granitemoeshared
         title: GraniteMoeShared
       - local: model_doc/helium
```
docs/source/en/model_doc/granitemoehybrid.md (new file, +64)

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GraniteMoeHybrid

## Overview

The `GraniteMoeHybrid` model builds on top of `GraniteMoeSharedModel` and `Bamba`. Its decoder layers consist of either state space (Mamba) layers or MoE attention layers with shared experts. By default, the attention layers do not use positional encoding.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()

# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."

# tokenize the text
input_tokens = tokenizer(prompt, return_tensors="pt")
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print, in this example the batch size is 1
for i in output:
    print(i)
```

This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).

## GraniteMoeHybridConfig

[[autodoc]] GraniteMoeHybridConfig

## GraniteMoeHybridModel

[[autodoc]] GraniteMoeHybridModel
    - forward

## GraniteMoeHybridForCausalLM

[[autodoc]] GraniteMoeHybridForCausalLM
    - forward

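To complement the new model doc, here is a minimal sketch for checking which decoder layers of the preview checkpoint are attention layers versus Mamba layers. It assumes the config exposes a per-layer block-type list named `layers_block_type` (following the `BambaConfig` convention, which the "map layer types to layer block type" step above suggests); the `GraniteMoeHybridConfig` docstring is the authoritative reference.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ibm-granite/granite-4.0-tiny-preview")

# Assumption: the per-layer block types live in `layers_block_type`
# (as in BambaConfig); fall back gracefully if the attribute differs.
layer_types = getattr(config, "layers_block_type", None)
if layer_types is not None:
    attention_idx = [i for i, t in enumerate(layer_types) if t == "attention"]
    print(f"{len(layer_types)} decoder layers; attention at indices {attention_idx}")
else:
    print("Layer block types are not exposed under this attribute name.")
```
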
src/transformers/models/__init__.py (+1)

```diff
@@ -129,6 +129,7 @@
     from .granite import *
     from .granite_speech import *
     from .granitemoe import *
+    from .granitemoehybrid import *
     from .granitemoeshared import *
     from .grounding_dino import *
     from .groupvit import *
```
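
With the sub-package wired into the models namespace (and the auto mappings below), the new classes should be importable from the top level. A quick smoke test, assuming the usual lazy top-level import plumbing accompanies this registration:

```python
# Assumes the top-level lazy imports pick up the new sub-package.
from transformers import (
    GraniteMoeHybridConfig,
    GraniteMoeHybridForCausalLM,
    GraniteMoeHybridModel,
)

print(GraniteMoeHybridForCausalLM.__name__)  # "GraniteMoeHybridForCausalLM"
```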

src/transformers/models/auto/configuration_auto.py (+2)

```diff
@@ -146,6 +146,7 @@
         ("granite", "GraniteConfig"),
         ("granite_speech", "GraniteSpeechConfig"),
         ("granitemoe", "GraniteMoeConfig"),
+        ("granitemoehybrid", "GraniteMoeHybridConfig"),
         ("granitemoeshared", "GraniteMoeSharedConfig"),
         ("granitevision", "LlavaNextConfig"),
         ("graphormer", "GraphormerConfig"),
@@ -509,6 +510,7 @@
         ("granite", "Granite"),
         ("granite_speech", "GraniteSpeech"),
         ("granitemoe", "GraniteMoeMoe"),
+        ("granitemoehybrid", "GraniteMoeHybrid"),
         ("granitemoeshared", "GraniteMoeSharedMoe"),
         ("granitevision", "LLaVA-NeXT"),
         ("graphormer", "Graphormer"),
```

src/transformers/models/auto/modeling_auto.py (+2)

```diff
@@ -138,6 +138,7 @@
         ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
         ("granite", "GraniteModel"),
         ("granitemoe", "GraniteMoeModel"),
+        ("granitemoehybrid", "GraniteMoeHybridModel"),
         ("granitemoeshared", "GraniteMoeSharedModel"),
         ("graphormer", "GraphormerModel"),
         ("grounding-dino", "GroundingDinoModel"),
@@ -558,6 +559,7 @@
         ("gptj", "GPTJForCausalLM"),
         ("granite", "GraniteForCausalLM"),
         ("granitemoe", "GraniteMoeForCausalLM"),
+        ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
         ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
         ("helium", "HeliumForCausalLM"),
         ("jamba", "JambaForCausalLM"),
```

src/transformers/models/bamba/modeling_bamba.py (+1)

```diff
@@ -854,6 +854,7 @@ def torch_forward(
             # Init cache
             if ssm_state is not None and cache_params is not None:
                 cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+                cache_params.has_previous_state = True
 
         scan_output = self.norm(y, gate)
```

src/transformers/models/bamba/modular_bamba.py (+1)

```diff
@@ -651,6 +651,7 @@ def torch_forward(
             # Init cache
             if ssm_state is not None and cache_params is not None:
                 cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+                cache_params.has_previous_state = True
 
         scan_output = self.norm(y, gate)
```

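The one-line addition above marks the hybrid Mamba cache as initialized once the eager `torch_forward` path has written its SSM state, so subsequent decode steps can reuse the cached state instead of re-running a full prefill. A toy sketch of the pattern (hypothetical names and a simplified update rule, not the transformers cache API):

```python
# Toy illustration of the `has_previous_state` handshake; names are
# hypothetical and the update rule is simplified, not the real SSM math.
class ToyMambaCache:
    def __init__(self, num_layers: int):
        self.ssm_states = [None] * num_layers
        self.has_previous_state = False  # becomes True after the first prefill

def toy_mixer_step(cache: ToyMambaCache, layer_idx: int, new_state: float) -> float:
    if cache.has_previous_state:
        # Decode step: reuse and update the cached state incrementally.
        cache.ssm_states[layer_idx] += new_state
    else:
        # Prefill: store the freshly computed state and flip the flag,
        # mirroring the added `cache_params.has_previous_state = True`.
        cache.ssm_states[layer_idx] = new_state
        cache.has_previous_state = True
    return cache.ssm_states[layer_idx]

cache = ToyMambaCache(num_layers=2)
toy_mixer_step(cache, 0, 1.0)   # prefill stores the state and sets the flag
toy_mixer_step(cache, 0, 0.5)   # decode reuses the cached state
```
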
src/transformers/models/granitemoe/configuration_granitemoe.py (+2)

```diff
@@ -166,6 +166,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        # this model has rope embedding type, hardcoded for BC
+        self.position_embedding_type = "rope"
 
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
```
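
The hardcoded attribute keeps `GraniteMoe` on rotary embeddings for backward compatibility, while the hybrid model's attention layers can omit positional encoding (see the model doc above). A quick check, assuming the stock default arguments construct a valid config:

```python
from transformers import GraniteMoeConfig

# The attribute is set in __init__ rather than exposed as an argument,
# so every GraniteMoeConfig reports rotary position embeddings.
config = GraniteMoeConfig()
print(config.position_embedding_type)  # "rope"
```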
