Skip to content

Commit 935608e

Browse files
committed
patch for paddle empty
1 parent 765ed54 commit 935608e

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

sonicmoe/triton_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
1-
import paddle
21
import sys
32

3+
import paddle
4+
5+
original_paddle_empty = paddle.empty
6+
7+
8+
def torch_compat_empty(*args, **kwargs):
9+
if "device" in kwargs and kwargs["device"] == "cuda":
10+
del kwargs["device"]
11+
return original_paddle_empty(*args, **kwargs)
12+
13+
414
def swap_torch_guard(fn):
515
def wrapped_fn(*args, **kwargs):
616
if "torch" not in sys.modules:
717
return fn(*args, **kwargs)
818
torch_module = sys.modules["torch"]
19+
original_paddle_empty = paddle.empty
920
sys.modules["torch"] = paddle
21+
paddle.empty = torch_compat_empty
1022
try:
1123
return fn(*args, **kwargs)
1224
finally:
1325
sys.modules["torch"] = torch_module
26+
paddle.empty = original_paddle_empty
1427

1528
return wrapped_fn
1629

@@ -26,4 +39,4 @@ def __getitem__(self, index):
2639
def __getattr__(self, name):
2740
return getattr(self.kernel, name)
2841

29-
return WrappedTritonKernel(triton_kernel)
42+
return WrappedTritonKernel(triton_kernel)

0 commit comments

Comments
 (0)