Skip to content

Commit 765ed54

Browse files
committed
update triton utils to use a lite guard
1 parent 4dee0f1 commit 765ed54

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

sonicmoe/triton_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
11
import paddle
2+
import sys
3+
4+
def swap_torch_guard(fn):
5+
def wrapped_fn(*args, **kwargs):
6+
if "torch" not in sys.modules:
7+
return fn(*args, **kwargs)
8+
torch_module = sys.modules["torch"]
9+
sys.modules["torch"] = paddle
10+
try:
11+
return fn(*args, **kwargs)
12+
finally:
13+
sys.modules["torch"] = torch_module
14+
15+
return wrapped_fn
16+
217

318
def wrap_triton_kernel(triton_kernel):
419
class WrappedTritonKernel:
520
def __init__(self, kernel):
621
self.kernel = kernel
722

823
def __getitem__(self, index):
9-
return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])
24+
return swap_torch_guard(self.kernel[index])
25+
26+
def __getattr__(self, name):
27+
return getattr(self.kernel, name)
28+
1029
return WrappedTritonKernel(triton_kernel)

0 commit comments

Comments
 (0)