@@ -209,6 +209,16 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
209209 _bm .make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
210210
211211
212+ # Two API generations of triton_kernels are supported:
213+ # - v3.5.1 (the version bundled with vLLM): exposes `routing()` and
214+ # `routing_from_bitmatrix()` in triton_kernels.routing; the `Bitmatrix`
215+ # constructor takes a `scratchpad` argument.
216+ # - v3.6.0+: removes the `routing` module in favor of a `SparseMatrix`
217+ # based path, and adds a `dtype=BIT` kwarg to `Bitmatrix`. Used only
218+ # when the user has triton_kernels installed system-wide at v3.6.0+.
219+ #
220+ # `use_legacy_triton_kernels` selects between them at import time based on
221+ # whether `SparseMatrix` is importable.
212222use_legacy_triton_kernels = False
213223
214224if has_triton_kernels ():
@@ -233,11 +243,10 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
233243 make_ragged_tensor_metadata ,
234244 )
235245 except ImportError :
236- if current_platform .is_rocm ():
237- logger .warning_once ("Using legacy triton_kernels on ROCm" )
238- use_legacy_triton_kernels = True
239- else :
240- raise
246+ # TODO(mgoin): drop the v3.5.1 pin and remove this fallback once
247+ # the gpt-oss perf regression in v3.6.0+ is resolved upstream.
248+ # Tracking: https://github.com/triton-lang/triton/issues/9969
249+ use_legacy_triton_kernels = True
241250 if not use_legacy_triton_kernels :
242251 _patch_make_bitmatrix_metadata ()
243252 except (AttributeError , ImportError ) as e :
@@ -311,38 +320,54 @@ def triton_kernel_moe_forward(
311320 unpadded_N_w2 = None ,
312321 unpadded_K_w2 = None ,
313322) -> torch .Tensor :
314- from triton_kernels .topk import topk as topk_fn
315-
316323 sm_first = not renormalize
317- logits = gating_output
318- if sm_first :
319- logits = torch .softmax (logits , dim = - 1 )
320- topk_result = topk_fn (logits , topk , apply_softmax = not sm_first )
321- # topk may return a tuple (vals, indx, bitmatrix) or a
322- # SparseMatrix depending on the triton_kernels version.
323- if isinstance (topk_result , tuple ):
324- topk_weights , topk_ids_raw , _ = topk_result
325- else :
326- topk_weights = topk_result .vals
327- topk_ids_raw = topk_result .indx
328324
329- if expert_map is not None :
330- # topk_ids_raw contains global expert IDs - remap to local.
331- topk_ids = expert_map [topk_ids_raw .to (torch .long )]
332- local_num_experts = w1 .shape [0 ]
333- routing_data , gather_idx , scatter_idx = make_routing_data (
334- topk_ids , topk_weights , local_num_experts
325+ # When no expert map is provided (no EP), call the fused `routing()`
326+ # kernel directly. It combines softmax, topk, bitmatrix packing, and
327+ # routing-metadata construction in a single launch, instead of the
328+ # three separate kernels used by the generic path below.
329+ # Only available in the legacy (v3.5.1) API; the v3.6.0+ path inlines
330+ # equivalent logic via SparseMatrix in `make_routing_data`.
331+ if use_legacy_triton_kernels and expert_map is None :
332+ from triton_kernels .routing import routing as fused_routing
333+
334+ routing_data , gather_idx , scatter_idx = fused_routing (
335+ gating_output , topk , sm_first = sm_first
335336 )
336- # expert_map already applied; pass None downstream.
337337 effective_expert_map = None
338- effective_global_num_experts = local_num_experts
339- else :
340- topk_ids = topk_ids_raw .to (torch .long )
341- routing_data , gather_idx , scatter_idx = make_routing_data (
342- topk_ids , topk_weights , gating_output .shape [- 1 ]
343- )
344- effective_expert_map = expert_map
345338 effective_global_num_experts = global_num_experts
339+ else :
340+ from triton_kernels .topk import topk as topk_fn
341+
342+ logits = gating_output
343+ if sm_first :
344+ logits = torch .softmax (logits , dim = - 1 )
345+ topk_result = topk_fn (logits , topk , apply_softmax = not sm_first )
346+ # topk may return a tuple (vals, indx, bitmatrix) or a
347+ # SparseMatrix depending on the triton_kernels version.
348+ if isinstance (topk_result , tuple ):
349+ topk_weights , topk_ids_raw , _ = topk_result
350+ else :
351+ topk_weights = topk_result .vals
352+ topk_ids_raw = topk_result .indx
353+
354+ if expert_map is not None :
355+ # topk_ids_raw contains global expert IDs - remap to local.
356+ topk_ids = expert_map [topk_ids_raw .to (torch .long )]
357+ local_num_experts = w1 .shape [0 ]
358+ routing_data , gather_idx , scatter_idx = make_routing_data (
359+ topk_ids , topk_weights , local_num_experts
360+ )
361+ # expert_map already applied; pass None downstream.
362+ effective_expert_map = None
363+ effective_global_num_experts = local_num_experts
364+ else :
365+ topk_ids = topk_ids_raw .to (torch .long )
366+ routing_data , gather_idx , scatter_idx = make_routing_data (
367+ topk_ids , topk_weights , gating_output .shape [- 1 ]
368+ )
369+ effective_expert_map = expert_map
370+ effective_global_num_experts = global_num_experts
346371
347372 output = torch .empty_like (hidden_states )
348373 effective_quant_config = (
0 commit comments