Skip to content

Commit 84813bd

Browse files
committed
wrap triton kernel with getitem
1 parent ee21e19 commit 84813bd

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

sonicmoe/functional/backward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
2222

23-
import paddle
23+
from ..triton_utils import wrap_triton_kernel
2424

2525
def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
2626
configs = []
@@ -29,7 +29,7 @@ def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
2929
return configs
3030

3131

32-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
32+
@wrap_triton_kernel
3333
@triton.autotune(
3434
configs=_get_autotune_configs_for_db2_and_ds(),
3535
key=["H", "E"],
@@ -126,7 +126,7 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
126126
return pruned_configs
127127

128128

129-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
129+
@wrap_triton_kernel
130130
@triton.autotune(
131131
configs=_get_autotune_configs_for_db1(),
132132
key=["I", "E"],
@@ -171,7 +171,7 @@ def db1_kernel(
171171
tl.store(db1_ptr + Eidx * I + i_offsets, db1_acc, mask=i_mask)
172172

173173

174-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
174+
@wrap_triton_kernel
175175
@triton.jit
176176
def _colsum_smallN_kernel(
177177
y_ptr, # *mut T, shape [M]
@@ -488,7 +488,7 @@ def _token_broadcast_backward(
488488
)
489489

490490

491-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
491+
@wrap_triton_kernel
492492
@triton.jit
493493
def _softmax_bwd_scatter_small_kernel(
494494
dlogits_ptr,
@@ -561,7 +561,7 @@ def _softmax_topk_bwd(
561561
)
562562

563563

564-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
564+
@wrap_triton_kernel
565565
@triton.jit
566566
def _topk_bwd_scatter_small_kernel(
567567
dlogits_full_ptr,

sonicmoe/functional/forward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
1717
from .topk_softmax import TopK_Softmax
1818

19-
import paddle
19+
from ..triton_utils import wrap_triton_kernel
2020

2121

2222
@torch.library.custom_op(f"{LIBRARY_NAME}::_topk_fwd", mutates_args={"values", "indices"})
@@ -204,7 +204,7 @@ def _router_forward(
204204
)
205205

206206

207-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
207+
@wrap_triton_kernel
208208
@triton.jit
209209
def _softmax_fwd_small_kernel(
210210
logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr

sonicmoe/functional/reduction_over_k_gather.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils import get_next_power_of_2, get_powers_of_2
2222
from .tile_scheduler import SonicMoETileScheduler
2323

24-
import paddle
24+
from ..triton_utils import wrap_triton_kernel
2525

2626

2727
def last_even(a: int):
@@ -440,7 +440,7 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
440440
return pruned_configs
441441

442442

443-
@paddle.use_compat_guard(enable=True, scope={"triton"}, silent=True)
443+
@wrap_triton_kernel
444444
@triton.autotune(
445445
configs=_get_triton_autotune_configs(),
446446
key=["H", "MAX_K", "w_is_None", "is_varlen_K"],

sonicmoe/triton_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import paddle
2+
3+
def wrap_triton_kernel(triton_kernel):
4+
class WrappedTritonKernel:
5+
def __init__(self, kernel):
6+
self.kernel = kernel
7+
8+
def __getitem__(self, index):
9+
return paddle.use_compat_guard(enable=True, scope={"triton"})(self.kernel[index])
10+
return WrappedTritonKernel(triton_kernel)

0 commit comments

Comments
 (0)