File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed
Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change 44
55import paddle
66import inspect
7+ import cutlass .cute
78
89if not (hasattr (paddle .library .CustomOpDef , "__call__" ) and inspect .isfunction (paddle .library .CustomOpDef .__call__ )):
910 def __call__ (self , * args , ** kwargs ):
@@ -22,6 +23,24 @@ def torch_compat_empty(*args, **kwargs):
2223 }
2324)
2425
26+ def cute_tensor_init (
27+ self ,
28+ tensor ,
29+ assumed_align = None ,
30+ ):
31+ # If tensor is already a DLPack object, use it directly
32+ if hasattr (tensor , "__dlpack_device__" ) and not hasattr (tensor , "__dlpack__" ):
33+ self ._dlpack_data = tensor
34+ else :
35+ self ._dlpack_data = tensor .__dlpack__ (stream = - 1 ) # Dont sync in cute 4.2.1, this already fixed in 4.3.0
36+ self ._dltensor_wrapper = None
37+ self ._assumed_align = assumed_align
38+ self ._is_dynamic = False
39+ self ._memref_desc = None
40+ self ._dtype = None
41+
42+ cutlass .cute .runtime ._Tensor .__init__ = cute_tensor_init
43+
2544from .count_cumsum import count_cumsum
2645from .enums import KernelBackendMoE
2746from .functional import enable_quack_gemm , moe_TC_softmax_topk_layer
You can’t perform that action at this time.
0 commit comments