Commit 28c2221

fix: clamp NaN/Inf in topk_softmax to prevent duplicate expert IDs (vllm-project#39391)
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
1 parent 3975eb6 commit 28c2221

2 files changed: 86 additions & 2 deletions

csrc/moe/topk_softmax_kernels.cu

Lines changed: 19 additions & 2 deletions
@@ -126,7 +126,9 @@ __launch_bounds__(TPB) __global__
     {
         const int idx = thread_row_offset + ii;
         const float val = toFloat(input[idx]);
-        const float softmax_val = expf(val - float_max) * normalizing_factor;
+        float softmax_val = expf(val - float_max) * normalizing_factor;
+        // Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
+        if (isnan(softmax_val) || isinf(softmax_val)) softmax_val = 0.f;
         output[idx] = softmax_val;
     }
 }
@@ -147,7 +149,9 @@ __launch_bounds__(TPB) __global__
     {
         const int idx = thread_row_offset + ii;
         const float val = toFloat(input[idx]);
-        const float sigmoid_val = 1.0f / (1.0f + __expf(-val));
+        float sigmoid_val = 1.0f / (1.0f + __expf(-val));
+        // Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
+        if (isnan(sigmoid_val) || isinf(sigmoid_val)) sigmoid_val = 0.f;
         output[idx] = sigmoid_val;
     }
 }
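
Both activation paths need the clamp because a single non-finite gating logit poisons the entire row. A minimal sketch of the failure in plain Python (standard float semantics rather than the kernel's expf/__expf intrinsics; variable names mirror the kernel's):

import math

logits = [1.0, 2.0, float("inf"), 0.5]                 # one poisoned gating logit
float_max = max(logits)                                # inf
exps = [math.exp(v - float_max) for v in logits]       # [0.0, 0.0, nan, 0.0]: inf - inf = nan
normalizing_factor = 1.0 / sum(exps)                   # 1 / nan = nan
softmax_vals = [e * normalizing_factor for e in exps]  # every entry in the row is nan

sigmoid_val = 1.0 / (1.0 + math.exp(-float("nan")))    # nan: the sigmoid path is poisoned too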
@@ -442,6 +446,19 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
         }
     }
 
+    // Fix: clamp NaN/Inf values to 0 to prevent duplicate expert IDs.
+    // NaN gating (from degenerate hidden states in CUDA graph padding) causes
+    // softmax to produce all-NaN, which makes the argmax loop always pick
+    // expert 0 for every top-k slot, producing duplicate expert IDs that
+    // crash FlashInfer's three-step MoE sort.
+    // With 0s, the argmax uses index tie-breaking to pick [0,1,2,...,k-1].
+#pragma unroll
+    for (int ii = 0; ii < VPT; ++ii) {
+        if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) {
+            row_chunk[ii] = 0.f;
+        }
+    }
+
     static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
 
     // If bias is not null, use biased value for selection
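
The comment in this hunk carries the key reasoning. A hedged Python sketch of the per-slot argmax it describes (naive_topk_ids is a hypothetical stand-in for the kernel's select-and-mask loop, not vLLM code):

import math

def naive_topk_ids(scores, k):
    # Per top-k slot: scan for the max with '>', then mask the winner.
    scores = list(scores)
    ids = []
    for _ in range(k):
        best_idx, best_val = 0, scores[0]
        for i in range(1, len(scores)):
            if scores[i] > best_val:   # every comparison against NaN is False
                best_idx, best_val = i, scores[i]
        ids.append(best_idx)
        scores[best_idx] = -math.inf   # mask the chosen expert
    return ids

print(naive_topk_ids([float("nan")] * 6, 3))  # [0, 0, 0]: duplicates, the reported crash
print(naive_topk_ids([0.0] * 6, 3))           # [0, 1, 2]: clamped row, ties break by index

With the clamp, a fully poisoned row degenerates to all zeros, so the index tie-break yields the unique IDs [0, 1, ..., k-1] that the new test asserts.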

tests/kernels/moe/test_fused_topk.py

Lines changed: 67 additions & 0 deletions
@@ -135,3 +135,70 @@ def test_fused_topk_bias(
         topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2
     )
     torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
+)
+@pytest.mark.parametrize("num_experts", [6, 8, 16])
+@pytest.mark.parametrize("topk", [3, 4])
+@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
+@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
+def test_fused_topk_nan_inf_clamp(
+    num_experts: int,
+    topk: int,
+    scoring_func: str,
+    bad_value: float,
+    dtype: torch.dtype,
+):
+    """Regression test for the NaN/Inf clamp in topk_softmax_kernels.cu.
+
+    Degenerate hidden states (e.g., from CUDA graph padding) can produce
+    NaN/Inf gating logits. Without the clamp, softmax/sigmoid outputs are
+    NaN and the argmax loop picks expert 0 for every top-k slot (since
+    "NaN > NaN" is false per IEEE 754), yielding duplicate expert IDs that
+    crash downstream MoE sort kernels. The fix clamps NaN/Inf to 0 before
+    argmax so index tie-breaking selects unique experts [0, 1, ..., k-1].
+    """
+    torch.manual_seed(0)
+    num_tokens = 4
+    hidden_size = 1024
+    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
+
+    # Row 0: all normal. Rows 1-3: fully poisoned with NaN or Inf.
+    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
+    gating_output[1:, :] = bad_value
+
+    topk_weights, topk_ids, _ = fused_topk(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=False,
+        scoring_func=scoring_func,
+    )
+
+    # Normal row must still match the torch reference.
+    ref_weights, ref_ids = torch_topk(
+        gating_output=gating_output[:1],
+        topk=topk,
+        renormalize=False,
+        scoring_func=scoring_func,
+    )
+    torch.testing.assert_close(
+        ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2
+    )
+    torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0)
+
+    # Poisoned rows: IDs must be unique (no duplicates) and weights must be
+    # finite (no NaN/Inf propagation into downstream MoE kernels).
+    for row in range(1, num_tokens):
+        row_ids = topk_ids[row]
+        assert row_ids.unique().numel() == topk, (
+            f"Row {row} has duplicate expert IDs {row_ids.tolist()} "
+            f"(bad_value={bad_value}, scoring_func={scoring_func})"
+        )
+        assert torch.isfinite(topk_weights[row]).all(), (
+            f"Row {row} has non-finite weights {topk_weights[row].tolist()} "
+            f"(bad_value={bad_value}, scoring_func={scoring_func})"
+        )
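
To run only this regression test on a CUDA machine, the usual pytest selection applies (this invocation is illustrative, not something the commit itself documents):

pytest tests/kernels/moe/test_fused_topk.py -k nan_inf_clamp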
