Skip to content

Commit bab9e36

Browse files
mgoin, Liuweixiong0118
authored and committed
[Perf][gpt-oss] Downgrade triton_kernels to v3.5.1 (vllm-project#43135)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Liuweixiong0118 <lwx34158427@gmail.com>
1 parent 90fe3f1 commit bab9e36

3 files changed

Lines changed: 73 additions & 57 deletions

File tree

cmake/external_projects/triton_kernels.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
22

3-
set(DEFAULT_TRITON_KERNELS_TAG "v3.6.0")
3+
set(DEFAULT_TRITON_KERNELS_TAG "v3.5.1")
44

55
# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
66
# be directly set to the triton_kernels python directory.

tests/kernels/quantization/test_mxfp4_triton_ep.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,33 @@
44
Tests that triton_kernel_moe_forward correctly applies expert_map
55
remapping when expert parallelism (EP) is enabled.
66
7-
Both EP and non-EP paths use topk + make_routing_data. When expert_map
8-
is provided, global expert IDs are remapped to local IDs before building
9-
routing structures.
7+
When expert_map is provided, global expert IDs are remapped to local IDs
8+
via topk + expert_map remap + make_routing_data before building routing
9+
structures, and the expert_map passed downstream to triton_kernel_fused_experts
10+
is None (already applied).
1011
"""
1112

1213
from unittest.mock import MagicMock, patch
1314

14-
import pytest
1515
import torch
1616

1717

1818
class TestTritonMoeForwardExpertMap:
1919
"""Test that triton_kernel_moe_forward applies expert_map remapping
2020
when expert_map is provided (EP active)."""
2121

22-
@pytest.mark.parametrize("expert_map_present", [False, True])
23-
def test_routing_path_selection(self, expert_map_present):
24-
"""Verify that both EP and non-EP paths use topk + make_routing_data,
25-
and that expert_map remapping is applied when present."""
26-
22+
def test_expert_map_remap(self):
2723
device = "cuda" if torch.cuda.is_available() else "cpu"
28-
mock_expert_map = (
29-
torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
30-
)
24+
mock_expert_map = torch.tensor([0, -1, 1, -1], device=device)
3125

3226
from vllm.utils.import_utils import import_triton_kernels
3327

3428
import_triton_kernels()
3529

30+
mock_routing_data = MagicMock()
31+
mock_gather = MagicMock()
32+
mock_scatter = MagicMock()
33+
3634
with (
3735
patch("triton_kernels.topk.topk") as mock_topk,
3836
patch(
@@ -48,14 +46,11 @@ def test_routing_path_selection(self, expert_map_present):
4846
triton_kernel_moe_forward,
4947
)
5048

51-
mock_routing_data = MagicMock()
52-
mock_gather = MagicMock()
53-
mock_scatter = MagicMock()
54-
5549
sparse_result = MagicMock()
5650
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
5751
sparse_result.vals = torch.tensor([[0.6, 0.4]])
5852
mock_topk.return_value = sparse_result
53+
5954
mock_make_routing.return_value = (
6055
mock_routing_data,
6156
mock_gather,
@@ -79,14 +74,10 @@ def test_routing_path_selection(self, expert_map_present):
7974
expert_map=mock_expert_map,
8075
)
8176

82-
# Both paths use topk + make_routing_data
8377
mock_topk.assert_called_once()
8478
mock_make_routing.assert_called_once()
8579

86-
if expert_map_present:
87-
# expert_map should be None in the fused_experts call
88-
# (already applied)
89-
call_kwargs = mock_fused_experts.call_args
90-
assert call_kwargs[1].get("expert_map") is None or (
91-
len(call_kwargs[0]) > 0
92-
)
80+
# expert_map should be None in the fused_experts call
81+
# (already applied).
82+
call_kwargs = mock_fused_experts.call_args
83+
assert call_kwargs[1].get("expert_map") is None or (len(call_kwargs[0]) > 0)

vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,16 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
209209
_bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe
210210

211211

212+
# Two API generations of triton_kernels are supported:
213+
# - v3.5.1 (the version bundled with vLLM): exposes `routing()` and
214+
# `routing_from_bitmatrix()` in triton_kernels.routing; the `Bitmatrix`
215+
# constructor takes a `scratchpad` argument.
216+
# - v3.6.0+: removes the `routing` module in favor of a `SparseMatrix`
217+
# based path, and adds a `dtype=BIT` kwarg to `Bitmatrix`. Used only
218+
# when the user has triton_kernels installed system-wide at v3.6.0+.
219+
#
220+
# `use_legacy_triton_kernels` selects between them at import time based on
221+
# whether `SparseMatrix` is importable.
212222
use_legacy_triton_kernels = False
213223

214224
if has_triton_kernels():
@@ -233,11 +243,10 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
233243
make_ragged_tensor_metadata,
234244
)
235245
except ImportError:
236-
if current_platform.is_rocm():
237-
logger.warning_once("Using legacy triton_kernels on ROCm")
238-
use_legacy_triton_kernels = True
239-
else:
240-
raise
246+
# TODO(mgoin): drop the v3.5.1 pin and remove this fallback once
247+
# the gpt-oss perf regression in v3.6.0+ is resolved upstream.
248+
# Tracking: https://github.com/triton-lang/triton/issues/9969
249+
use_legacy_triton_kernels = True
241250
if not use_legacy_triton_kernels:
242251
_patch_make_bitmatrix_metadata()
243252
except (AttributeError, ImportError) as e:
@@ -311,38 +320,54 @@ def triton_kernel_moe_forward(
311320
unpadded_N_w2=None,
312321
unpadded_K_w2=None,
313322
) -> torch.Tensor:
314-
from triton_kernels.topk import topk as topk_fn
315-
316323
sm_first = not renormalize
317-
logits = gating_output
318-
if sm_first:
319-
logits = torch.softmax(logits, dim=-1)
320-
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
321-
# topk may return a tuple (vals, indx, bitmatrix) or a
322-
# SparseMatrix depending on the triton_kernels version.
323-
if isinstance(topk_result, tuple):
324-
topk_weights, topk_ids_raw, _ = topk_result
325-
else:
326-
topk_weights = topk_result.vals
327-
topk_ids_raw = topk_result.indx
328324

329-
if expert_map is not None:
330-
# topk_ids_raw contains global expert IDs - remap to local.
331-
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
332-
local_num_experts = w1.shape[0]
333-
routing_data, gather_idx, scatter_idx = make_routing_data(
334-
topk_ids, topk_weights, local_num_experts
325+
# When no expert map is provided (no EP), call the fused `routing()`
326+
# kernel directly. It combines softmax, topk, bitmatrix packing, and
327+
# routing-metadata construction in a single launch, instead of the
328+
# three separate kernels used by the generic path below.
329+
# Only available in the legacy (v3.5.1) API; the v3.6.0+ path inlines
330+
# equivalent logic via SparseMatrix in `make_routing_data`.
331+
if use_legacy_triton_kernels and expert_map is None:
332+
from triton_kernels.routing import routing as fused_routing
333+
334+
routing_data, gather_idx, scatter_idx = fused_routing(
335+
gating_output, topk, sm_first=sm_first
335336
)
336-
# expert_map already applied; pass None downstream.
337337
effective_expert_map = None
338-
effective_global_num_experts = local_num_experts
339-
else:
340-
topk_ids = topk_ids_raw.to(torch.long)
341-
routing_data, gather_idx, scatter_idx = make_routing_data(
342-
topk_ids, topk_weights, gating_output.shape[-1]
343-
)
344-
effective_expert_map = expert_map
345338
effective_global_num_experts = global_num_experts
339+
else:
340+
from triton_kernels.topk import topk as topk_fn
341+
342+
logits = gating_output
343+
if sm_first:
344+
logits = torch.softmax(logits, dim=-1)
345+
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
346+
# topk may return a tuple (vals, indx, bitmatrix) or a
347+
# SparseMatrix depending on the triton_kernels version.
348+
if isinstance(topk_result, tuple):
349+
topk_weights, topk_ids_raw, _ = topk_result
350+
else:
351+
topk_weights = topk_result.vals
352+
topk_ids_raw = topk_result.indx
353+
354+
if expert_map is not None:
355+
# topk_ids_raw contains global expert IDs - remap to local.
356+
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
357+
local_num_experts = w1.shape[0]
358+
routing_data, gather_idx, scatter_idx = make_routing_data(
359+
topk_ids, topk_weights, local_num_experts
360+
)
361+
# expert_map already applied; pass None downstream.
362+
effective_expert_map = None
363+
effective_global_num_experts = local_num_experts
364+
else:
365+
topk_ids = topk_ids_raw.to(torch.long)
366+
routing_data, gather_idx, scatter_idx = make_routing_data(
367+
topk_ids, topk_weights, gating_output.shape[-1]
368+
)
369+
effective_expert_map = expert_map
370+
effective_global_num_experts = global_num_experts
346371

347372
output = torch.empty_like(hidden_states)
348373
effective_quant_config = (

0 commit comments

Comments
 (0)