Skip to content

Commit 79d2e37

Browse files
[TRTLLM-11770][feat] Skip nvfp4 fused norm if the dim doesn't meet the requirement (NVIDIA#12901)
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
1 parent b1385fe commit 79d2e37

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

tensorrt_llm/_torch/modules/mamba/layernorm_gated.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/layernorm_gated.py
22
# Copyright (c) 2024, Tri Dao, Albert Gu.
33
#
4-
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
55
# SPDX-License-Identifier: Apache-2.0
66
#
77
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,6 +23,23 @@
2323
from ...utils import Fp4QuantizedTensor
2424

2525

26+
def fused_gated_rmsnorm_quant_shape_ok(hidden_size: int,
                                       group_size: int) -> bool:
    """Return True when ``torch.ops.trtllm.fused_gated_rmsnorm_quant`` accepts this shape.

    Mirrors the TORCH_CHECKs in cpp/tensorrt_llm/thop/fusedGatedRMSNormQuant.cpp;
    keep the two in sync.
    """
    # The group size must be positive and evenly tile the hidden dimension.
    divides_evenly = group_size > 0 and hidden_size % group_size == 0
    if not divides_evenly:
        return False
    # The kernel only supports group sizes that are multiples of 256 within
    # [256, 8192]; the hidden dimension must be 16-aligned for FP4 packing.
    group_ok = group_size % 256 == 0 and 256 <= group_size <= 8192
    hidden_ok = hidden_size % 16 == 0
    return group_ok and hidden_ok
41+
42+
2643
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
2744
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
2845
@triton.jit
@@ -208,7 +225,8 @@ def forward(
208225

209226
# NVFP4 quantized path - uses optimized fused CUDA kernel
210227
# Fuses: SiLU gating + Group RMSNorm + FP4 quantization
211-
if self.is_nvfp4 and z is not None and not self.norm_before_gate:
228+
if self.is_nvfp4 and z is not None and not self.norm_before_gate and \
229+
fused_gated_rmsnorm_quant_shape_ok(self.hidden_size, self.group_size):
212230
if self.nvfp4_scale is None:
213231
raise ValueError(
214232
"RMSNormGated NVFP4 output requested but no `nvfp4_scale` is attached. "

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -39,6 +39,7 @@
3939
from .fuse_elementwise_ops import (extract_transpose_xbc_prefill,
4040
fused_split_rearrange_after_conv1d)
4141
from .layernorm_gated import RMSNorm as RMSNormGated
42+
from .layernorm_gated import fused_gated_rmsnorm_quant_shape_ok
4243
from .selective_state_update import \
4344
selective_state_update as selective_state_update_native
4445
from .ssd_combined import mamba_chunk_scan_combined
@@ -234,7 +235,9 @@ def __init__(
234235

235236
def post_load_weights(self):
236237
"""Post-process after loading weights."""
237-
if self.norm.is_nvfp4 and self.norm.nvfp4_scale is None:
238+
if (self.norm.is_nvfp4 and fused_gated_rmsnorm_quant_shape_ok(
239+
self.norm.hidden_size, self.norm.group_size)
240+
and self.norm.nvfp4_scale is None):
238241
self._try_attach_nvfp4_scale()
239242

240243
def _try_attach_nvfp4_scale(self):

0 commit comments

Comments (0)