File tree Expand file tree Collapse file tree 1 file changed +20
-1
lines changed
Expand file tree Collapse file tree 1 file changed +20
-1
lines changed Original file line number Diff line number Diff line change 11import paddle
2+ import sys
3+
4+ def swap_torch_guard (fn ):
5+ def wrapped_fn (* args , ** kwargs ):
6+ if "torch" not in sys .modules :
7+ return fn (* args , ** kwargs )
8+ torch_module = sys .modules ["torch" ]
9+ sys .modules ["torch" ] = paddle
10+ try :
11+ return fn (* args , ** kwargs )
12+ finally :
13+ sys .modules ["torch" ] = torch_module
14+
15+ return wrapped_fn
16+
217
318def wrap_triton_kernel (triton_kernel ):
419 class WrappedTritonKernel :
520 def __init__ (self , kernel ):
621 self .kernel = kernel
722
823 def __getitem__ (self , index ):
9- return paddle .use_compat_guard (enable = True , silent = True )(self .kernel [index ])
24+ return swap_torch_guard (self .kernel [index ])
25+
26+ def __getattr__ (self , name ):
27+ return getattr (self .kernel , name )
28+
1029 return WrappedTritonKernel (triton_kernel )
You can’t perform that action at this time.
0 commit comments