Skip to content

Commit 4dee0f1

Browse files
committed
patch for cute 4.2.1
1 parent a6f794e commit 4dee0f1

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

sonicmoe/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import paddle
66
import inspect
7+
import cutlass.cute
78

89
if not (hasattr(paddle.library.CustomOpDef, "__call__") and inspect.isfunction(paddle.library.CustomOpDef.__call__)):
910
def __call__(self, *args, **kwargs):
@@ -22,6 +23,24 @@ def torch_compat_empty(*args, **kwargs):
2223
}
2324
)
2425

26+
def cute_tensor_init(
    self,
    tensor,
    assumed_align=None,
):
    """Replacement ``__init__`` for the CuTe runtime tensor (patch for cute 4.2.1).

    Stores the DLPack data for ``tensor`` on ``self`` and resets the
    wrapper's cached state. The DLPack export is requested with
    ``stream=-1`` so the producer does not synchronize (cute 4.2.1 does not
    do this itself; it is already fixed in 4.3.0).

    Args:
        tensor: Either an object that is already a DLPack object (it has
            ``__dlpack_device__`` but no ``__dlpack__``), which is stored
            as-is, or an object exposing ``__dlpack__``.
        assumed_align: Stored unchanged on ``self._assumed_align``.
    """
    # If tensor is already a DLPack object, use it directly.
    if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
        self._dlpack_data = tensor
    else:
        try:
            # Don't sync in cute 4.2.1; this is already fixed in 4.3.0.
            self._dlpack_data = tensor.__dlpack__(stream=-1)
        except (TypeError, ValueError, RuntimeError):
            # Some producers take no ``stream`` keyword or reject any value
            # other than None. Retry with the plain call so such tensors
            # keep working; a genuine export failure is raised again by the
            # retry.
            self._dlpack_data = tensor.__dlpack__()
    self._dltensor_wrapper = None
    self._assumed_align = assumed_align
    self._is_dynamic = False
    self._memref_desc = None
    self._dtype = None
41+
42+
# Monkeypatch: install cute_tensor_init as the constructor of cutlass's
# runtime _Tensor class, so instances created after this point use it.
cutlass.cute.runtime._Tensor.__init__ = cute_tensor_init
43+
2544
from .count_cumsum import count_cumsum
2645
from .enums import KernelBackendMoE
2746
from .functional import enable_quack_gemm, moe_TC_softmax_topk_layer

0 commit comments

Comments
 (0)