Skip to content

Commit 2f02816

Browse files
authored
[None][feat] AutoDeploy: Onboard google/gemma-4-31B-it dense model, including nvfp4 (NVIDIA#12866)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
1 parent ce80c14 commit 2f02816

File tree

7 files changed

+616
-34
lines changed

7 files changed

+616
-34
lines changed

docs/source/models/supported-models.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The following is a table of supported models for the PyTorch backend:
1414
| `ExaoneMoEForCausalLM` | K-EXAONE | `LGAI-EXAONE/K-EXAONE-236B-A23B` |
1515
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it` |
1616
| `Gemma3nForConditionalGeneration` [^8]| Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it` |
17-
| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it` |
17+
| `Gemma4ForConditionalGeneration` [^7]| Gemma 4 | `google/gemma-4-26B-A4B-it`, `google/gemma-4-31B-it` |
1818
| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6, GLM-4.7 | `THUDM/GLM-4-100B-A10B` |
1919
| `Glm4MoeLiteForCausalLM` [^6] | GLM-4.7-Flash | `zai-org/GLM-4.7-Flash` |
2020
| `GlmMoeDsaForCausalLM` | GLM-5 | `zai-org/GLM-5` |
@@ -62,7 +62,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
6262
[^4]: Overlap scheduler isn't supported when using EAGLE-3 (Two Model Engine) for GPT-OSS.
6363
[^5]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml).
6464
[^6]: Supported via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/glm-4.7-flash.yaml).
65-
[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml).
65+
[^7]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See AD configs for [MoE](../../../examples/auto_deploy/model_registry/configs/gemma4_moe.yaml) and [dense](../../../examples/auto_deploy/model_registry/configs/gemma4_dense.yaml).
6666
[^8]: Text-only support via the [AutoDeploy](../features/auto_deploy/auto-deploy.md) backend. See [AD config](../../../examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml).
6767

6868

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Gemma 4 dense (31B) — text-only AD export path.
5+
# Uses triton paged attention backend: supports head_dim=512 (global_head_dim),
6+
# paged KV cache, CUDA-graph-compatible, FlashDecoding for decode.
7+
model_factory: Gemma4ForConditionalGeneration
8+
tokenizer: google/gemma-4-31B-it
9+
attn_backend: triton_paged
10+
compile_backend: torch-cudagraph
11+
cuda_graph_config:
12+
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
13+
max_num_tokens: 8192
14+
max_batch_size: 512
15+
max_seq_len: 8192
16+
enable_chunked_prefill: true
17+
kv_cache_config:
18+
enable_block_reuse: false
19+
free_gpu_memory_fraction: 0.8
20+
transforms:
21+
compile_model:
22+
piecewise_enabled: true
23+
mlir_elementwise_fusion:
24+
enabled: true
25+
gather_logits_before_lm_head:
26+
enabled: true
27+
fuse_gemms:
28+
enabled: true

examples/auto_deploy/model_registry/models.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ models:
315315
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe_base.yaml']
316316
- name: google/gemma-4-26B-A4B-it
317317
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml', 'gemma4_moe.yaml']
318+
# --- Gemma 4 (2026) - Dense 31B ---
319+
- name: google/gemma-4-31B-it
320+
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'gemma4_dense.yaml']
318321
# --- JetBrains Mellum (Apr 2025) - code specialist ---
319322
- name: JetBrains/Mellum-4b-sft-all
320323
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -347,14 +347,14 @@ def _flash_decode_stage1_kernel(
347347
)
348348
page_mask_2d = page_mask[:, None]
349349

350-
k = tl.load(
351-
kv_cache_ptr + cache_base, mask=page_mask_2d, other=0.0
352-
) # [PAGE_SIZE, HEAD_DIM]
350+
k = tl.load(kv_cache_ptr + cache_base, mask=page_mask_2d, other=0.0).to(
351+
q_all.dtype
352+
) # [PAGE_SIZE, HEAD_DIM]; cast from fp8 if kv cache is fp8
353353
v = tl.load(
354354
kv_cache_ptr + cache_base + cache_stride_kv,
355355
mask=page_mask_2d,
356356
other=0.0,
357-
) # [PAGE_SIZE, HEAD_DIM]
357+
).to(q_all.dtype) # [PAGE_SIZE, HEAD_DIM]; cast from fp8 if kv cache is fp8
358358

359359
# [HEAD_RATIO_PADDED, HEAD_DIM] @ [HEAD_DIM, PAGE_SIZE] -> [HEAD_RATIO_PADDED, PAGE_SIZE]
360360
attn = tl.dot(q_all, tl.trans(k)) * SM_SCALE
@@ -728,12 +728,12 @@ def _paged_context_kernel(
728728
kv_cache_ptr + page_base + local_kv,
729729
mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1),
730730
other=0.0,
731-
)
731+
).to(q.dtype) # cast from fp8 if kv cache is fp8
732732
v = tl.load(
733733
kv_cache_ptr + page_base + local_kv + cache_stride_kv,
734734
mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1),
735735
other=0.0,
736-
)
736+
).to(q.dtype) # cast from fp8 if kv cache is fp8
737737

738738
qk = tl.dot(q, tl.trans(k)) * SM_SCALE
739739

@@ -745,24 +745,16 @@ def _paged_context_kernel(
745745
full_mask_p1 = q_mask[:, None] & sw_mask
746746
qk = tl.where(full_mask_p1, qk, float("-inf"))
747747
else:
748-
k_block_ptr = tl.make_block_ptr(
749-
base=kv_cache_ptr + page_base,
750-
shape=(PAGE_SIZE, HEAD_DIM),
751-
strides=(cache_stride_token, 1),
752-
offsets=(0, 0),
753-
block_shape=(PAGE_SIZE, HEAD_DIM),
754-
order=(1, 0),
755-
)
756-
v_block_ptr = tl.make_block_ptr(
757-
base=kv_cache_ptr + page_base + cache_stride_kv,
758-
shape=(PAGE_SIZE, HEAD_DIM),
759-
strides=(cache_stride_token, 1),
760-
offsets=(0, 0),
761-
block_shape=(PAGE_SIZE, HEAD_DIM),
762-
order=(1, 0),
763-
)
764-
k = tl.load(k_block_ptr)
765-
v = tl.load(v_block_ptr)
748+
k = tl.load(
749+
kv_cache_ptr + page_base + local_kv,
750+
mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1),
751+
other=0.0,
752+
).to(q.dtype) # cast from fp8 if kv cache is fp8
753+
v = tl.load(
754+
kv_cache_ptr + page_base + local_kv + cache_stride_kv,
755+
mask=tl.full([PAGE_SIZE, HEAD_DIM], 1, tl.int1),
756+
other=0.0,
757+
).to(q.dtype) # cast from fp8 if kv cache is fp8
766758

767759
qk = tl.dot(q, tl.trans(k)) * SM_SCALE
768760

@@ -799,12 +791,14 @@ def _paged_context_kernel(
799791
# Use int64 to avoid overflow when physical_page * stride > 2^31
800792
page_base = physical_page.to(tl.int64) * cache_stride_block + kv_head_offset
801793
page_mask_2d = page_mask[:, None]
802-
k = tl.load(kv_cache_ptr + page_base + local_kv, mask=page_mask_2d, other=0.0)
794+
k = tl.load(kv_cache_ptr + page_base + local_kv, mask=page_mask_2d, other=0.0).to(
795+
q.dtype
796+
) # cast from fp8 if kv cache is fp8
803797
v = tl.load(
804798
kv_cache_ptr + page_base + local_kv + cache_stride_kv,
805799
mask=page_mask_2d,
806800
other=0.0,
807-
)
801+
).to(q.dtype) # cast from fp8 if kv cache is fp8
808802

809803
qk = tl.dot(q, tl.trans(k)) * SM_SCALE
810804
kv_positions = kv_base_pos + page_offsets[None, :]
@@ -938,11 +932,24 @@ def triton_paged_context(
938932

939933
max_pages = (max_q_len + page_size - 1) // page_size
940934
total_expected_pages = num_seq * max_pages
935+
# Force SDPA for large head_dim: the Triton paged kernel's tl.dot produces
936+
# misaligned shared memory accesses on Blackwell when HEAD_DIM > 256.
937+
large_head_dim = head_dim > 256
938+
# kv_indices may be a pre-allocated buffer larger than the actual page count;
939+
# fall back to the page table indptr which always reflects the true count.
940+
pages_uniform = kv_indices.shape[0] == total_expected_pages or (
941+
max_pages > 0 and int(kv_indptr[-1].item()) == total_expected_pages
942+
)
943+
# SDPA reshape requires all sequences to have the same q_len (since q is
944+
# packed as [total_tokens, ...] and we reshape to [num_seq, max_q_len, ...]).
945+
# Check without a GPU sync: since each q_len_i <= max_q_len, the sum
# sum(q_len_i) reaches num_seq * max_q_len iff every q_len_i == max_q_len.
946+
all_same_q_len = total_tokens == num_seq * max_q_len
941947
use_sdpa = (
942-
max_q_len >= 512
943-
and num_seq <= 64
948+
(max_q_len >= 512 or large_head_dim)
949+
and (num_seq <= 64 or large_head_dim)
944950
and max_pages > 0
945-
and kv_indices.shape[0] == total_expected_pages # all seqs same page count
951+
and pages_uniform
952+
and all_same_q_len
946953
and sw == 0 # SDPA doesn't support sliding window natively
947954
)
948955

@@ -979,6 +986,11 @@ def triton_paged_context(
979986
HEAD_DIM=head_dim,
980987
)
981988

989+
# Cast k/v to query dtype if kv cache uses a different dtype (e.g., fp8)
990+
if kv_cache.dtype != q.dtype:
991+
k_sdpa = k_sdpa.to(q.dtype)
992+
v_sdpa = v_sdpa.to(q.dtype)
993+
982994
# SDPA with GQA
983995
o_sdpa = torch.nn.functional.scaled_dot_product_attention(
984996
q.view(num_seq, max_q_len, n_heads, head_dim).transpose(1, 2),

tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,13 @@ def has(cls, reader_cls: str) -> bool:
8585

8686
@QuantConfigReaderRegistry.register("modelopt")
8787
class ModelOPTQuantConfigReader(QuantConfigReader):
88-
_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens", "*.mixer.gate*", "*.mlp.gate")
88+
_ALWAYS_EXCLUDE = (
89+
"lm_head",
90+
"model.embed_tokens",
91+
"*.embed_tokens",
92+
"*.mixer.gate*",
93+
"*.mlp.gate",
94+
)
8995
DEFAULT_TORCH_DTYPE = "float16"
9096
DEFAULT_KV_CACHE_DTYPE = "fp8"
9197

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,55 @@ def _small_text_config() -> Gemma4TextConfig:
107107
return config
108108

109109

110+
def _small_dense_text_config() -> Gemma4TextConfig:
111+
"""Small config mimicking gemma-4-31B-it (dense, no MoE)."""
112+
config = Gemma4TextConfig(
113+
vocab_size=256,
114+
hidden_size=64,
115+
intermediate_size=128,
116+
num_hidden_layers=3,
117+
num_attention_heads=4,
118+
num_key_value_heads=2,
119+
num_global_key_value_heads=1,
120+
head_dim=16,
121+
global_head_dim=32,
122+
hidden_activation="gelu_pytorch_tanh",
123+
max_position_embeddings=64,
124+
rms_norm_eps=1e-6,
125+
attention_bias=False,
126+
attention_dropout=0.0,
127+
attention_k_eq_v=True,
128+
sliding_window=16,
129+
layer_types=["sliding_attention", "sliding_attention", "full_attention"],
130+
enable_moe_block=False,
131+
num_experts=None,
132+
top_k_experts=None,
133+
expert_intermediate_size=None,
134+
final_logit_softcapping=30.0,
135+
hidden_size_per_layer_input=0,
136+
num_kv_shared_layers=0,
137+
use_double_wide_mlp=False,
138+
use_bidirectional_attention="vision",
139+
rope_parameters={
140+
"full_attention": {
141+
"rope_type": "proportional",
142+
"rope_theta": 1000000.0,
143+
"partial_rotary_factor": 0.25,
144+
},
145+
"sliding_attention": {
146+
"rope_type": "default",
147+
"rope_theta": 10000.0,
148+
},
149+
},
150+
pad_token_id=0,
151+
eos_token_id=1,
152+
bos_token_id=2,
153+
tie_word_embeddings=True,
154+
)
155+
config._attn_implementation = "eager"
156+
return config
157+
158+
110159
def _position_ids(batch_size: int, seq_len: int, device: str) -> torch.Tensor:
111160
return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
112161

@@ -695,3 +744,133 @@ def test_export():
695744
logits2 = out2[0] if isinstance(out2, tuple) else getattr(out2, "logits", out2)
696745
assert logits2.shape == (B2, S2, config.vocab_size)
697746
assert torch.isfinite(logits2).all()
747+
748+
749+
# ---------------------------------------------------------------------------
750+
# Tests — Dense variant (gemma-4-31B-it style, no MoE)
751+
# ---------------------------------------------------------------------------
752+
753+
754+
def test_dense_decoder_layer_equivalence():
755+
"""Dense (non-MoE) decoder layer matches reference for sliding and full attention."""
756+
device, dtype = _device_and_dtype()
757+
config = _small_dense_text_config()
758+
759+
for layer_idx in [0, 2]:
760+
layer_type = config.layer_types[layer_idx]
761+
ref = _RefDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval()
762+
ad = Gemma4TextDecoderLayer(config, layer_idx).to(device=device, dtype=dtype).eval()
763+
_load_ref_into_ad(ad, ref)
764+
765+
B, S = 2, 8
766+
x = torch.randn(B, S, config.hidden_size, device=device, dtype=dtype)
767+
pos_ids = _position_ids(B, S, device)
768+
rope = _build_ref_rope(config, layer_type, device, dtype)
769+
cos, sin = rope(x, pos_ids)
770+
771+
causal_mask = (
772+
torch.triu(torch.full((S, S), float("-inf"), device=device, dtype=dtype), diagonal=1)
773+
.unsqueeze(0)
774+
.unsqueeze(0)
775+
)
776+
777+
with torch.no_grad():
778+
ad_out = ad(x, (cos, sin))
779+
ref_out = ref(x, (cos, sin), attention_mask=causal_mask)
780+
assert_rmse_close(
781+
ad_out,
782+
ref_out,
783+
rmse_ratio_tol=0.05,
784+
msg=f"Dense layer {layer_idx} ({layer_type}): ",
785+
)
786+
787+
788+
def test_dense_full_model_equivalence():
789+
"""Dense CausalLM logits (no MoE) match reference."""
790+
device, dtype = _device_and_dtype()
791+
config = _small_dense_text_config()
792+
793+
ref = _RefForCausalLM(config).to(device=device, dtype=dtype).eval()
794+
ad = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval()
795+
_transfer_ref_to_ad_full_model(ad, ref)
796+
797+
B, S = 2, 8
798+
input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
799+
pos_ids = _position_ids(B, S, device)
800+
801+
with torch.no_grad():
802+
ref_logits = ref(input_ids, pos_ids)
803+
ad_out = ad(input_ids=input_ids, position_ids=pos_ids)
804+
805+
assert ad_out.logits.shape == (B, S, config.vocab_size)
806+
assert torch.isfinite(ad_out.logits).all()
807+
assert_rmse_close(ad_out.logits, ref_logits, rmse_ratio_tol=0.05, msg="Dense full model: ")
808+
809+
810+
def test_dense_conditional_generation_wrapper():
811+
"""ConditionalGeneration wrapper works with dense (non-MoE) text config."""
812+
device, dtype = _device_and_dtype()
813+
config = Gemma4Config(
814+
text_config=_small_dense_text_config(),
815+
vision_config=Gemma4VisionConfig(hidden_size=32),
816+
)
817+
model = Gemma4ForConditionalGeneration(config).to(device=device, dtype=dtype).eval()
818+
819+
B, S = 2, 8
820+
input_ids = torch.randint(0, config.text_config.vocab_size, (B, S), device=device)
821+
pos_ids = _position_ids(B, S, device)
822+
823+
with torch.no_grad():
824+
out = model(input_ids=input_ids, position_ids=pos_ids)
825+
assert out.logits is not None
826+
assert out.logits.shape == (B, S, config.text_config.vocab_size)
827+
assert torch.isfinite(out.logits).all()
828+
829+
830+
def test_dense_export():
831+
"""Dense model (no MoE) can be exported with torch.export."""
832+
device = "cpu"
833+
dtype = torch.float32
834+
config = _small_dense_text_config()
835+
836+
model = Gemma4ForCausalLM(config).to(device=device, dtype=dtype).eval()
837+
838+
B, S = 2, 8
839+
input_ids = torch.randint(0, config.vocab_size, (B, S), device=device)
840+
pos_ids = _position_ids(B, S, device)
841+
842+
batch_dim = Dim("batch", min=1, max=4)
843+
seq_dim = Dim("seq", min=1, max=64)
844+
dynamic_shapes = {
845+
"input_ids": {0: batch_dim, 1: seq_dim},
846+
"position_ids": {0: batch_dim, 1: seq_dim},
847+
}
848+
849+
gm = torch_export_to_gm(
850+
model,
851+
args=(input_ids,),
852+
kwargs={"position_ids": pos_ids},
853+
dynamic_shapes=dynamic_shapes,
854+
)
855+
856+
with torch.no_grad():
857+
pre_export_out = model(input_ids=input_ids, position_ids=pos_ids)
858+
exported_out = gm(input_ids, position_ids=pos_ids)
859+
860+
logits = (
861+
exported_out[0]
862+
if isinstance(exported_out, tuple)
863+
else getattr(exported_out, "logits", exported_out)
864+
)
865+
assert torch.isfinite(logits).all(), "Dense export produced non-finite values"
866+
torch.testing.assert_close(logits, pre_export_out.logits, rtol=1e-3, atol=1e-3)
867+
868+
# Test different shape
869+
B2, S2 = 1, 4
870+
ids2 = torch.randint(0, config.vocab_size, (B2, S2), device=device)
871+
pos2 = _position_ids(B2, S2, device)
872+
with torch.no_grad():
873+
out2 = gm(ids2, position_ids=pos2)
874+
logits2 = out2[0] if isinstance(out2, tuple) else getattr(out2, "logits", out2)
875+
assert logits2.shape == (B2, S2, config.vocab_size)
876+
assert torch.isfinite(logits2).all()

0 commit comments

Comments
 (0)