
Commit 344a1a1

✨ support granite 4 dense (#635)
# Description

Adds support for configuring the unreleased granite 4 dense models the same way as the granite 3 models, so they work out of the box.

---------

Signed-off-by: Joe Runde <joe@joerun.de>
1 parent 566b786 commit 344a1a1

File tree

6 files changed (+142, -20 lines)
tests/fixtures/model_configs/ibm-granite/granite-4-8b-dense/config.json

Lines changed: 90 additions & 0 deletions

```diff
@@ -0,0 +1,90 @@
+{
+  "architectures": [
+    "GraniteMoeHybridForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 100257,
+  "embedding_multiplier": 12,
+  "eos_token_id": 100257,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "init_method": "mup",
+  "initializer_range": 0.1,
+  "intermediate_size": 12800,
+  "layer_types": [
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention",
+    "attention"
+  ],
+  "logits_scaling": 16,
+  "mamba_chunk_size": 256,
+  "mamba_conv_bias": true,
+  "mamba_d_conv": 4,
+  "mamba_d_head": 64,
+  "mamba_d_state": 256,
+  "mamba_expand": 2,
+  "mamba_n_groups": 1,
+  "mamba_n_heads": 128,
+  "mamba_proj_bias": false,
+  "max_position_embeddings": 131072,
+  "model_type": "granitemoehybrid",
+  "normalization_function": "rmsnorm",
+  "num_attention_heads": 32,
+  "num_experts_per_tok": 0,
+  "num_hidden_layers": 40,
+  "num_key_value_heads": 8,
+  "num_local_experts": 0,
+  "output_router_logits": false,
+  "pad_token_id": 100256,
+  "position_embedding_type": "rope",
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000,
+  "router_aux_loss_coef": 0.01,
+  "shared_intermediate_size": 12800,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.56.0",
+  "use_cache": true,
+  "vocab_size": 100352
+}
```
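The fixture pins exactly the fields the new detection check reads. As a quick sanity sketch (the fixture path is assumed from the tests in this commit; run from the repo root), the key invariants of a dense granite 4 config can be verified like this:

```python
import json

# Path assumed from the test fixtures added in this commit.
with open("tests/fixtures/model_configs/ibm-granite/granite-4-8b-dense/config.json") as f:
    cfg = json.load(f)

# Dense granite 4: one "attention" entry per hidden layer, and no MoE routing
# (num_experts_per_tok and num_local_experts are both 0).
assert len(cfg["layer_types"]) == cfg["num_hidden_layers"] == 40
assert all(layer == "attention" for layer in cfg["layer_types"])
assert cfg["num_experts_per_tok"] == 0 and cfg["num_local_experts"] == 0
```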

tests/fixtures/model_configs/ibm-granite/granite4/config.json

Lines changed: 0 additions & 6 deletions
This file was deleted.

tests/models/test_granite.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -38,19 +38,35 @@ def test_granite_3_8b_detection():
 
     assert not SpyrePlatform.is_granite_3_8b(granite_micro_config.model_config)
 
+    assert not SpyrePlatform.is_granite_4_8b_dense(granite_3_8b_config.model_config)
+
+
+@pytest.mark.cpu
+def test_granite_4_dense_detection():
+    """Check that we can detect the model config for granite 4 8b (dense)"""
+
+    granite_4_dense_config = VllmConfig(
+        model_config=ModelConfig(model=str(FIXTURES_PATH / "ibm-granite" / "granite-4-8b-dense")),
+        cache_config=NO_SWAP_CONFIG(),
+    )
+
+    assert SpyrePlatform.is_granite_4_8b_dense(granite_4_dense_config.model_config)
+    assert not SpyrePlatform.is_granite_3_8b(granite_4_dense_config.model_config)
+
 
 @pytest.mark.cpu
 @pytest.mark.parametrize(
-    "sendnn_configured, sendnn_version, expected_blocks",
+    "model_name, sendnn_configured, sendnn_version, expected_blocks",
     [
-        (True, (0, 0, 0), 8192),
-        (True, (1, 0, 2), 2080),
-        (True, (1, 1, 0), 8192),
-        (False, (1, 0, 2), 8192),
+        ("granite-3.3-8b-instruct", True, (0, 0, 0), 8192),
+        ("granite-3.3-8b-instruct", True, (1, 0, 2), 2080),
+        ("granite-3.3-8b-instruct", True, (1, 1, 0), 8192),
+        ("granite-3.3-8b-instruct", False, (1, 0, 2), 8192),
+        ("granite-4-8b-dense", True, (1, 1, 0), 8192),
     ],
     ids=lambda vals: f"{vals}",
 )
-def test_granite_3_8b_overrides(sendnn_configured, sendnn_version, expected_blocks):
+def test_granite_overrides(model_name, sendnn_configured, sendnn_version, expected_blocks):
     """Check that the correct values are overridden for g3.3 8b"""
 
     # Must ensure no env vars have been overridden before testing
@@ -64,9 +80,7 @@ def test_granite_3_8b_overrides(sendnn_configured, sendnn_version, expected_blocks):
     tp4_config = ParallelConfig(tensor_parallel_size=4)
 
     granite_3_8b_config = VllmConfig(
-        model_config=ModelConfig(
-            model=str(FIXTURES_PATH / "ibm-granite" / "granite-3.3-8b-instruct")
-        ),
+        model_config=ModelConfig(model=str(FIXTURES_PATH / "ibm-granite" / model_name)),
         parallel_config=tp4_config,
         cache_config=NO_SWAP_CONFIG(),
     )
```
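The new assertions boil down to: each fixture is recognized by exactly one detector. A standalone sketch of that check, assuming vllm and vllm-spyre are importable and that `FIXTURES_PATH` mirrors the test module's fixtures directory:

```python
from pathlib import Path

from vllm.config import ModelConfig
from vllm_spyre.platform import SpyrePlatform

FIXTURES_PATH = Path("tests/fixtures/model_configs")  # assumed, as in the test module

for name in ("granite-3.3-8b-instruct", "granite-4-8b-dense"):
    mc = ModelConfig(model=str(FIXTURES_PATH / "ibm-granite" / name))
    # Exactly one of the two detectors should fire for each fixture.
    print(name, SpyrePlatform.is_granite_3_8b(mc), SpyrePlatform.is_granite_4_8b_dense(mc))
```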

vllm_spyre/config/known_model_configs.json

Lines changed: 4 additions & 2 deletions
```diff
@@ -23,8 +23,10 @@
       "format": "float-quantized"
     }
   },
-  "ibm-granite/granite4": {
-    "model_type": "granitemoehybrid"
+  "ibm-granite/granite-4-8b-dense": {
+    "model_type": "granitemoehybrid",
+    "vocab_size": 100352,
+    "num_experts_per_tok": 0
   },
   "ibm-granite/granite-embedding-125m-english": {
     "model_type": "roberta",
```

vllm_spyre/config/supported_configs.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@
     { cb: True, tp_size: 4, max_model_len: 16384, max_num_seqs: 4 },
     { cb: True, tp_size: 4, max_model_len: 32768, max_num_seqs: 32 },
   ]
-- model: "ibm-granite/granite4"
+- model: "ibm-granite/granite-4-8b-dense"
   configs: [
     { cb: False, tp_size: 1, warmup_shapes: [[2048, 1024, 16]] },
     { cb: False, tp_size: 4, warmup_shapes: [[6144, 2048, 1]] },
```
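Renaming the entry keeps the supported-configuration list keyed by the published model id. A sketch of reading the renamed entry back, assuming PyYAML is available and that the file's top level is the list of `model`/`configs` mappings the diff context suggests:

```python
import yaml  # PyYAML, assumed available

with open("vllm_spyre/config/supported_configs.yaml") as f:
    supported = yaml.safe_load(f)

entry = next(m for m in supported if m["model"] == "ibm-granite/granite-4-8b-dense")
for cfg in entry["configs"]:
    # e.g. {'cb': False, 'tp_size': 1, 'warmup_shapes': [[2048, 1024, 16]]}
    print(cfg)
```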

vllm_spyre/platform.py

Lines changed: 24 additions & 2 deletions
```diff
@@ -1,5 +1,7 @@
 import sys
 
+from transformers import GraniteMoeHybridConfig
+
 # When running this plugin on a Mac, we assume it's for local development
 # purposes. However, due to a compatibility issue with vLLM, which overrides
 # the Triton module with a placeholder, vLLM may fail to load on macOS. To
@@ -206,7 +208,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         )
 
         # Hardcode some things for granite-3.3-8b-instruct
-        if cls.is_granite_3_8b(vllm_config.model_config):
+        if cls.is_granite_3_8b(vllm_config.model_config) or cls.is_granite_4_8b_dense(
+            vllm_config.model_config
+        ):
             cls.configure_granite_3_8b(vllm_config)
 
         # To disable any paged attention ops in the base scheduler, we:
@@ -731,7 +735,7 @@ def is_granite_3_8b(cls, model_config: ModelConfig):
         """Returns true if we have a model that looks like
         ibm-granite/granite-3.3-8b-instruct"""
         if not isinstance(model_config.hf_config, GraniteConfig):
-            # Not granite at all
+            # Not granite 3 at all
             return False
 
         return (
@@ -743,6 +747,24 @@ def is_granite_3_8b(cls, model_config: ModelConfig):
             and model_config.hf_config.num_attention_heads == 32
         )
 
+    @classmethod
+    def is_granite_4_8b_dense(cls, model_config: ModelConfig):
+        """Returns true if we have a dense granite 4 model with the same
+        architecture as granite 3.3 8b"""
+        if not isinstance(model_config.hf_config, GraniteMoeHybridConfig):
+            # Not granite 4 at all
+            return False
+
+        return (
+            model_config.hf_config.num_hidden_layers == 40
+            and model_config.hf_config.num_experts_per_tok == 0  # dense model
+            and model_config.hf_config.max_position_embeddings == 131072
+            and model_config.hf_config.hidden_size == 4096
+            and model_config.hf_config.vocab_size == 100352
+            and model_config.hf_config.num_key_value_heads == 8
+            and model_config.hf_config.num_attention_heads == 32
+        )
+
     @classmethod
     def sendnn_configured(cls) -> bool:
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
```
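The detector gates on the `hf_config` class before reading any fields, so anything that does not parse as a `GraniteMoeHybridConfig` short-circuits to False. The same shape test can be reproduced with transformers alone (a minimal sketch; assumes a transformers version new enough to ship `GraniteMoeHybridConfig`, and reuses the fixture path from the tests):

```python
from transformers import AutoConfig, GraniteMoeHybridConfig

cfg = AutoConfig.from_pretrained("tests/fixtures/model_configs/ibm-granite/granite-4-8b-dense")

# Mirrors is_granite_4_8b_dense: class gate first, then the distinguishing fields.
looks_like_g4_dense = isinstance(cfg, GraniteMoeHybridConfig) and (
    cfg.num_hidden_layers == 40
    and cfg.num_experts_per_tok == 0  # dense: no MoE routing
    and cfg.hidden_size == 4096
    and cfg.vocab_size == 100352
)
print(looks_like_g4_dense)
```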
