File tree Expand file tree Collapse file tree 1 file changed +15
-2
lines changed
Expand file tree Collapse file tree 1 file changed +15
-2
lines changed Original file line number Diff line number Diff line change 1- import paddle
21import sys
32
3+ import paddle
4+
5+ original_paddle_empty = paddle .empty
6+
7+
8+ def torch_compat_empty (* args , ** kwargs ):
9+ if "device" in kwargs and kwargs ["device" ] == "cuda" :
10+ del kwargs ["device" ]
11+ return original_paddle_empty (* args , ** kwargs )
12+
13+
414def swap_torch_guard (fn ):
515 def wrapped_fn (* args , ** kwargs ):
616 if "torch" not in sys .modules :
717 return fn (* args , ** kwargs )
818 torch_module = sys .modules ["torch" ]
19+ original_paddle_empty = paddle .empty
920 sys .modules ["torch" ] = paddle
21+ paddle .empty = torch_compat_empty
1022 try :
1123 return fn (* args , ** kwargs )
1224 finally :
1325 sys .modules ["torch" ] = torch_module
26+ paddle .empty = original_paddle_empty
1427
1528 return wrapped_fn
1629
@@ -26,4 +39,4 @@ def __getitem__(self, index):
2639 def __getattr__ (self , name ):
2740 return getattr (self .kernel , name )
2841
29- return WrappedTritonKernel (triton_kernel )
42+ return WrappedTritonKernel (triton_kernel )
You can’t perform that action at this time.
0 commit comments