1111
1212import jax
1313import jax .core
14- import jaxlib
15- import jaxlib .xla_extension
1614import torch
1715import torch .func
1816import torch .utils ._pytree
1917from jax .dlpack import from_dlpack as jax_from_dlpack # type: ignore
20- from torch .utils .dlpack import to_dlpack as torch_to_dlpack # type: ignore
2118
2219from .types import (
2320 Dataclass ,
3431
3532
3633@overload
37- def torch_to_jax (value : torch .Tensor , / ) -> jax .Array :
38- ...
34+ def torch_to_jax (value : torch .Tensor , / ) -> jax .Array : ...
3935
4036
4137@overload
42- def torch_to_jax (value : torch .device , / ) -> jax .Device :
43- ...
38+ def torch_to_jax (value : torch .device , / ) -> jax .Device : ...
4439
4540
4641@overload
47- def torch_to_jax (value : tuple [torch .Tensor , ...], / ) -> tuple [jax .Array , ...]:
48- ...
42+ def torch_to_jax (value : tuple [torch .Tensor , ...], / ) -> tuple [jax .Array , ...]: ...
4943
5044
5145@overload
52- def torch_to_jax (value : list [torch .Tensor ], / ) -> list [jax .Array ]:
53- ...
46+ def torch_to_jax (value : list [torch .Tensor ], / ) -> list [jax .Array ]: ...
5447
5548
5649@overload
57- def torch_to_jax (value : NestedDict [K , torch .Tensor ], / ) -> NestedDict [K , jax .Array ]:
58- ...
50+ def torch_to_jax (value : NestedDict [K , torch .Tensor ], / ) -> NestedDict [K , jax .Array ]: ...
5951
6052
6153@overload
62- def torch_to_jax (value : Any , / ) -> Any :
63- ...
54+ def torch_to_jax (value : Any , / ) -> Any : ...
6455
6556
6657def torch_to_jax (value : Any , / ) -> Any :
@@ -99,16 +90,14 @@ def _direct_conversion(v: torch.Tensor) -> jax.Array:
9990 return jax_from_dlpack (v , copy = False )
10091
10192
102- def _to_from_dlpack (
103- v : torch .Tensor , ignore_deprecation_warning : bool = True
104- ) -> jax .Array :
93+ def _to_from_dlpack (v : torch .Tensor , ignore_deprecation_warning : bool = True ) -> jax .Array :
10594 with warnings .catch_warnings () if ignore_deprecation_warning else contextlib .nullcontext ():
10695 # Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated
10796 # conversion method for now.
10897 # todo: Should we let it though though?
10998 if ignore_deprecation_warning :
11099 warnings .filterwarnings ("ignore" , category = DeprecationWarning )
111- return jax_from_dlpack (torch_to_dlpack ( v ) , copy = False )
100+ return jax_from_dlpack (v , copy = False )
112101
113102
114103def torch_to_jax_tensor (value : torch .Tensor ) -> jax .Array :
@@ -130,7 +119,7 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
130119 # return _direct_conversion(value)
131120 return _to_from_dlpack (value , ignore_deprecation_warning = True )
132121
133- except jaxlib . xla_extension . XlaRuntimeError as err :
122+ except RuntimeError as err :
134123 log_once (
135124 logger ,
136125 message = (
@@ -145,7 +134,7 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
145134
146135 try :
147136 return _direct_conversion (value )
148- except jaxlib . xla_extension . XlaRuntimeError as err :
137+ except RuntimeError as err :
149138 log_once (
150139 logger ,
151140 message = (
0 commit comments