Skip to content

Commit bfd0bed

Browse files
committed
wrap triton kernels with compat guard
1 parent 62ef2dc commit bfd0bed

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

sonicmoe/functional/backward.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
2222

23+
import paddle
2324

2425
def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
2526
configs = []
@@ -28,6 +29,7 @@ def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]:
2829
return configs
2930

3031

32+
@paddle.use_compat_guard(enable=True, scope={"triton"})
3133
@triton.autotune(
3234
configs=_get_autotune_configs_for_db2_and_ds(),
3335
key=["H", "E"],
@@ -124,6 +126,7 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
124126
return pruned_configs
125127

126128

129+
@paddle.use_compat_guard(enable=True, scope={"triton"})
127130
@triton.autotune(
128131
configs=_get_autotune_configs_for_db1(),
129132
key=["I", "E"],
@@ -168,6 +171,7 @@ def db1_kernel(
168171
tl.store(db1_ptr + Eidx * I + i_offsets, db1_acc, mask=i_mask)
169172

170173

174+
@paddle.use_compat_guard(enable=True, scope={"triton"})
171175
@triton.jit
172176
def _colsum_smallN_kernel(
173177
y_ptr, # *mut T, shape [M]
@@ -484,6 +488,7 @@ def _token_broadcast_backward(
484488
)
485489

486490

491+
@paddle.use_compat_guard(enable=True, scope={"triton"})
487492
@triton.jit
488493
def _softmax_bwd_scatter_small_kernel(
489494
dlogits_ptr,
@@ -556,6 +561,7 @@ def _softmax_topk_bwd(
556561
)
557562

558563

564+
@paddle.use_compat_guard(enable=True, scope={"triton"})
559565
@triton.jit
560566
def _topk_bwd_scatter_small_kernel(
561567
dlogits_full_ptr,

sonicmoe/functional/forward.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
1717
from .topk_softmax import TopK_Softmax
1818

19+
import paddle
20+
1921

2022
@torch.library.custom_op(f"{LIBRARY_NAME}::_topk_fwd", mutates_args={"values", "indices"})
2123
def _topk_fwd(
@@ -202,6 +204,7 @@ def _router_forward(
202204
)
203205

204206

207+
@paddle.use_compat_guard(enable=True, scope={"triton"})
205208
@triton.jit
206209
def _softmax_fwd_small_kernel(
207210
logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr

sonicmoe/functional/reduction_over_k_gather.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from ..utils import get_next_power_of_2, get_powers_of_2
2222
from .tile_scheduler import SonicMoETileScheduler
2323

24+
import paddle
25+
2426

2527
def last_even(a: int):
2628
return a if a % 2 == 0 else a - 1
@@ -438,6 +440,7 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
438440
return pruned_configs
439441

440442

443+
@paddle.use_compat_guard(enable=True, scope={"triton"})
441444
@triton.autotune(
442445
configs=_get_triton_autotune_configs(),
443446
key=["H", "MAX_K", "w_is_None", "is_varlen_K"],

0 commit comments

Comments (0)