|
1 | 1 | # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/layernorm_gated.py |
2 | 2 | # Copyright (c) 2024, Tri Dao, Albert Gu. |
3 | 3 | # |
4 | | -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 4 | +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
5 | 5 | # SPDX-License-Identifier: Apache-2.0 |
6 | 6 | # |
7 | 7 | # Licensed under the Apache License, Version 2.0 (the "License"); |
|
23 | 23 | from ...utils import Fp4QuantizedTensor |
24 | 24 |
|
25 | 25 |
|
def fused_gated_rmsnorm_quant_shape_ok(hidden_size: int,
                                       group_size: int) -> bool:
    """Return True iff ``torch.ops.trtllm.fused_gated_rmsnorm_quant`` accepts this shape.

    Mirrors (and must stay in sync with) the TORCH_CHECKs in
    cpp/tensorrt_llm/thop/fusedGatedRMSNormQuant.cpp: the group size must be a
    positive multiple of 256 in [256, 8192] that evenly divides ``hidden_size``,
    and ``hidden_size`` itself must be a multiple of 16.
    """
    # A single conjunction of the same checks, evaluated left-to-right so the
    # group_size > 0 guard protects the modulo tests, exactly as the original
    # early-return chain did.
    return (group_size > 0
            and hidden_size % group_size == 0
            and group_size % 256 == 0
            and 256 <= group_size <= 8192
            and hidden_size % 16 == 0)
| 41 | + |
| 42 | + |
26 | 43 | @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) |
27 | 44 | @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) |
28 | 45 | @triton.jit |
@@ -208,7 +225,8 @@ def forward( |
208 | 225 |
|
209 | 226 | # NVFP4 quantized path - uses optimized fused CUDA kernel |
210 | 227 | # Fuses: SiLU gating + Group RMSNorm + FP4 quantization |
211 | | - if self.is_nvfp4 and z is not None and not self.norm_before_gate: |
| 228 | + if self.is_nvfp4 and z is not None and not self.norm_before_gate and \ |
| 229 | + fused_gated_rmsnorm_quant_shape_ok(self.hidden_size, self.group_size): |
212 | 230 | if self.nvfp4_scale is None: |
213 | 231 | raise ValueError( |
214 | 232 | "RMSNormGated NVFP4 output requested but no `nvfp4_scale` is attached. " |
|
0 commit comments