Skip to content

Commit a78e733

Browse files
authored
Support topk_sigmoid kernel for MoE (vllm-project#148)
1 parent 9d35aa8 commit a78e733

8 files changed

Lines changed: 475 additions & 173 deletions

File tree

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import torch
99
import triton
1010

11-
from tests.ops.topk_softmax_op import fused_topk, topk_softmax
11+
from tests.ops.topk_op import (fused_topk_sigmoid, fused_topk_softmax,
12+
topk_sigmoid, topk_softmax)
1213

1314

1415
@torch.compile
@@ -28,6 +29,23 @@ def topk_softmax_compile(
2829
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
2930

3031

32+
@torch.compile
33+
def topk_sigmoid_compile(
34+
hidden_states: torch.Tensor,
35+
gating_output: torch.Tensor,
36+
topk: int,
37+
renormalize: bool,
38+
indices_type: Optional[torch.dtype] = None,
39+
) -> tuple[torch.Tensor, torch.Tensor]:
40+
41+
routing_weights = torch.sigmoid(gating_output).to(torch.float32)
42+
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
43+
44+
if renormalize:
45+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
46+
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
47+
48+
3149
n_token_range = [1, 64, 256]
3250
n_expert_range = [16, 192, 512, 1024]
3351
topk_range = [2, 4]
@@ -43,7 +61,19 @@ def topk_softmax_compile(
4361
))
4462

4563

46-
def get_benchmark():
64+
def get_benchmark(scoring_func: str):
65+
if scoring_func == "softmax":
66+
fused_fn = fused_topk_softmax
67+
native_fn = topk_softmax
68+
compile_fn = topk_softmax_compile
69+
plot_name = "topk_softmax-perf"
70+
elif scoring_func == "sigmoid":
71+
fused_fn = fused_topk_sigmoid
72+
native_fn = topk_sigmoid
73+
compile_fn = topk_sigmoid_compile
74+
plot_name = "topk_sigmoid-perf"
75+
else:
76+
raise ValueError(f"Unsupported scoring_func: {scoring_func}")
4777

4878
@triton.testing.perf_report(
4979
triton.testing.Benchmark(
@@ -61,7 +91,7 @@ def get_benchmark():
6191
styles=[("blue", "-"), ("green", "-"), ("orange", "-"),
6292
("red", "-")],
6393
ylabel="us",
64-
plot_name="topk_softmax-perf",
94+
plot_name=plot_name,
6595
args={},
6696
))
6797
def benchmark(
@@ -84,26 +114,26 @@ def benchmark(
84114

85115
if provider == "vllm":
86116
ms, min_ms, max_ms = triton.testing.do_bench(
87-
lambda: fused_topk(hidden_states=hidden_states,
88-
gating_output=gating_output,
89-
topk=topk,
90-
renormalize=renormalize),
117+
lambda: fused_fn(hidden_states=hidden_states,
118+
gating_output=gating_output,
119+
topk=topk,
120+
renormalize=renormalize),
91121
quantiles=quantiles,
92122
)
93123
elif provider == "native":
94124
ms, min_ms, max_ms = triton.testing.do_bench(
95-
lambda: topk_softmax(hidden_states=hidden_states,
96-
gating_output=gating_output,
97-
topk=topk,
98-
renormalize=renormalize),
125+
lambda: native_fn(hidden_states=hidden_states,
126+
gating_output=gating_output,
127+
topk=topk,
128+
renormalize=renormalize),
99129
quantiles=quantiles,
100130
)
101131
elif provider == "compile":
102132
ms, min_ms, max_ms = triton.testing.do_bench(
103-
lambda: topk_softmax_compile(hidden_states=hidden_states,
104-
gating_output=gating_output,
105-
topk=topk,
106-
renormalize=renormalize),
133+
lambda: compile_fn(hidden_states=hidden_states,
134+
gating_output=gating_output,
135+
topk=topk,
136+
renormalize=renormalize),
107137
quantiles=quantiles,
108138
)
109139

@@ -113,15 +143,22 @@ def benchmark(
113143

114144

115145
if __name__ == "__main__":
116-
parser = ArgumentParser(description="Benchmark the topk_softmax kernel.")
146+
parser = ArgumentParser(description="Benchmark the topk kernel.")
147+
parser.add_argument(
148+
"--scoring-func",
149+
choices=["softmax", "sigmoid"],
150+
default="softmax",
151+
help="Scoring function to benchmark",
152+
)
117153
parser.add_argument(
118154
"--save-path",
119155
type=str,
120-
default="./configs/topk_softmax/",
121-
help="Path to save topk_softmax benchmark results",
156+
default="./configs/topk/",
157+
help="Path to save topk benchmark results",
122158
)
123159

124160
args = parser.parse_args()
125161

126-
benchmark = get_benchmark()
127-
benchmark.run(print_data=True, save_path=args.save_path)
162+
benchmark = get_benchmark(args.scoring_func)
163+
save_path = f"{args.save_path.rstrip('/')}/{args.scoring_func}"
164+
benchmark.run(print_data=True, save_path=save_path)

csrc/moe/moe_ops.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ void topk_softmax(
6464
const bool renormalize,
6565
std::optional<torch::Tensor> bias);
6666

67+
void topk_sigmoid(
68+
torch::Tensor& topk_weights,
69+
torch::Tensor& topk_indices,
70+
torch::Tensor& token_expert_indices,
71+
torch::Tensor& gating_output,
72+
const bool renormalize,
73+
std::optional<torch::Tensor> bias);
74+
6775
void moe_gather(
6876
torch::Tensor& output,
6977
const torch::Tensor& moe_output,

0 commit comments

Comments
 (0)