88import torch
99import triton
1010
11- from tests .ops .topk_softmax_op import fused_topk , topk_softmax
11+ from tests .ops .topk_op import (fused_topk_sigmoid , fused_topk_softmax ,
12+ topk_sigmoid , topk_softmax )
1213
1314
1415@torch .compile
@@ -28,6 +29,23 @@ def topk_softmax_compile(
2829 return topk_weights .to (torch .float32 ), topk_ids .to (torch .int32 )
2930
3031
def topk_sigmoid_eager(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sigmoid-scored top-k expert routing (eager reference implementation).

    Scores each expert with an element-wise sigmoid of ``gating_output``,
    selects the ``topk`` highest-scoring experts per token, and optionally
    renormalizes the selected weights so each row sums to 1.

    Args:
        hidden_states: Unused; kept so the signature matches the fused and
            softmax variants benchmarked alongside this function.
        gating_output: Router logits; top-k is taken over the last dim
            (presumably shape (num_tokens, num_experts) — matches the
            softmax variant's usage).
        topk: Number of experts to select per token.
        renormalize: If True, rescale the selected weights to sum to 1
            per row.
        indices_type: dtype for the returned expert indices. Defaults to
            ``torch.int32`` (the previously hard-coded behavior).

    Returns:
        (topk_weights, topk_ids): float32 routing weights and integer
        expert ids, each with ``topk`` entries along the last dim.
    """
    routing_weights = torch.sigmoid(gating_output).to(torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Fix: ``indices_type`` was accepted but silently ignored (ids were
    # always cast to int32). Honor it when provided; the None default
    # preserves the original int32 behavior for existing callers.
    ids_dtype = torch.int32 if indices_type is None else indices_type
    return topk_weights.to(torch.float32), topk_ids.to(ids_dtype)


# Compiled entry point used by the benchmark; identical behavior to the
# eager reference above.
topk_sigmoid_compile = torch.compile(topk_sigmoid_eager)
48+
3149n_token_range = [1 , 64 , 256 ]
3250n_expert_range = [16 , 192 , 512 , 1024 ]
3351topk_range = [2 , 4 ]
@@ -43,7 +61,19 @@ def topk_softmax_compile(
4361 ))
4462
4563
46- def get_benchmark ():
64+ def get_benchmark (scoring_func : str ):
65+ if scoring_func == "softmax" :
66+ fused_fn = fused_topk_softmax
67+ native_fn = topk_softmax
68+ compile_fn = topk_softmax_compile
69+ plot_name = "topk_softmax-perf"
70+ elif scoring_func == "sigmoid" :
71+ fused_fn = fused_topk_sigmoid
72+ native_fn = topk_sigmoid
73+ compile_fn = topk_sigmoid_compile
74+ plot_name = "topk_sigmoid-perf"
75+ else :
76+ raise ValueError (f"Unsupported scoring_func: { scoring_func } " )
4777
4878 @triton .testing .perf_report (
4979 triton .testing .Benchmark (
@@ -61,7 +91,7 @@ def get_benchmark():
6191 styles = [("blue" , "-" ), ("green" , "-" ), ("orange" , "-" ),
6292 ("red" , "-" )],
6393 ylabel = "us" ,
64- plot_name = "topk_softmax-perf" ,
94+ plot_name = plot_name ,
6595 args = {},
6696 ))
6797 def benchmark (
@@ -84,26 +114,26 @@ def benchmark(
84114
85115 if provider == "vllm" :
86116 ms , min_ms , max_ms = triton .testing .do_bench (
87- lambda : fused_topk (hidden_states = hidden_states ,
88- gating_output = gating_output ,
89- topk = topk ,
90- renormalize = renormalize ),
117+ lambda : fused_fn (hidden_states = hidden_states ,
118+ gating_output = gating_output ,
119+ topk = topk ,
120+ renormalize = renormalize ),
91121 quantiles = quantiles ,
92122 )
93123 elif provider == "native" :
94124 ms , min_ms , max_ms = triton .testing .do_bench (
95- lambda : topk_softmax (hidden_states = hidden_states ,
96- gating_output = gating_output ,
97- topk = topk ,
98- renormalize = renormalize ),
125+ lambda : native_fn (hidden_states = hidden_states ,
126+ gating_output = gating_output ,
127+ topk = topk ,
128+ renormalize = renormalize ),
99129 quantiles = quantiles ,
100130 )
101131 elif provider == "compile" :
102132 ms , min_ms , max_ms = triton .testing .do_bench (
103- lambda : topk_softmax_compile (hidden_states = hidden_states ,
104- gating_output = gating_output ,
105- topk = topk ,
106- renormalize = renormalize ),
133+ lambda : compile_fn (hidden_states = hidden_states ,
134+ gating_output = gating_output ,
135+ topk = topk ,
136+ renormalize = renormalize ),
107137 quantiles = quantiles ,
108138 )
109139
@@ -113,15 +143,22 @@ def benchmark(
113143
114144
if __name__ == "__main__":
    # CLI entry point: pick the scoring function to benchmark and where to
    # write the results, then run the triton perf report.
    parser = ArgumentParser(description="Benchmark the topk kernel.")
    parser.add_argument(
        "--scoring-func",
        choices=["softmax", "sigmoid"],
        default="softmax",
        help="Scoring function to benchmark",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/topk/",
        help="Path to save topk benchmark results",
    )
    args = parser.parse_args()

    benchmark = get_benchmark(args.scoring_func)
    # Results for each scoring function go into their own subdirectory so
    # softmax and sigmoid runs don't overwrite each other.
    save_path = "/".join([args.save_path.rstrip("/"), args.scoring_func])
    benchmark.run(print_data=True, save_path=save_path)
0 commit comments