@@ -135,3 +135,70 @@ def test_fused_topk_bias(
135135 topk_weights_ref .to (torch .float32 ), topk_weights , atol = 1e-2 , rtol = 1e-2
136136 )
137137 torch .testing .assert_close (topk_ids_ref .to (torch .int32 ), topk_ids , atol = 0 , rtol = 0 )
138+
139+
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_experts", [6, 8, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_nan_inf_clamp(
    num_experts: int,
    topk: int,
    scoring_func: str,
    bad_value: float,
    dtype: torch.dtype,
):
    """Regression test for the NaN/Inf clamp in topk_softmax_kernels.cu.

    Degenerate hidden states (e.g., produced by CUDA graph padding) can
    yield NaN/Inf gating logits. Without the clamp, softmax/sigmoid emits
    NaN and the argmax loop selects expert 0 for every top-k slot (since
    "NaN > NaN" is false under IEEE 754), producing duplicate expert IDs
    that crash downstream MoE sort kernels. With the clamp, NaN/Inf are
    zeroed before argmax, so index tie-breaking selects the unique experts
    [0, 1, ..., k-1].
    """
    torch.manual_seed(0)
    num_tokens = 4
    hidden_size = 1024
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")

    # Keep row 0 well-formed; fully poison every remaining row with the
    # parametrized bad value (NaN or Inf).
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
    gating_output[1:, :] = bad_value

    topk_weights, topk_ids, _ = fused_topk(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=False,
        scoring_func=scoring_func,
    )

    # The clean row must still agree with the torch reference implementation.
    ref_weights, ref_ids = torch_topk(
        gating_output=gating_output[:1],
        topk=topk,
        renormalize=False,
        scoring_func=scoring_func,
    )
    torch.testing.assert_close(
        ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0)

    # Poisoned rows: expert IDs must all be distinct and weights must be
    # finite, so nothing pathological flows into downstream MoE kernels.
    for row in range(1, num_tokens):
        row_ids = topk_ids[row]
        distinct = int(torch.unique(row_ids).numel())
        assert distinct == topk, (
            f"Row {row} has duplicate expert IDs {row_ids.tolist()} "
            f"(bad_value={bad_value}, scoring_func={scoring_func})"
        )
        row_weights = topk_weights[row]
        assert bool(torch.isfinite(row_weights).all()), (
            f"Row {row} has non-finite weights {row_weights.tolist()} "
            f"(bad_value={bad_value}, scoring_func={scoring_func})"
        )
0 commit comments