@@ -93,8 +93,14 @@ def _direct_conversion(v: torch.Tensor) -> jax.Array:
9393 return jax_from_dlpack (v , copy = False )
9494
9595
96- def _to_from_dlpack (v : torch .Tensor , ignore_deprecation_warning : bool = True ) -> jax .Array :
97- with warnings .catch_warnings () if ignore_deprecation_warning else contextlib .nullcontext ():
96+ def _to_from_dlpack (
97+ v : torch .Tensor , ignore_deprecation_warning : bool = True
98+ ) -> jax .Array :
99+ with (
100+ warnings .catch_warnings ()
101+ if ignore_deprecation_warning
102+ else contextlib .nullcontext ()
103+ ):
98104 # Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated
99105 # conversion method for now.
100106 # todo: Should we let it though though?
@@ -136,7 +142,29 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
136142 return _direct_conversion (value .flatten ()).reshape (value .shape )
137143
138144 try :
139- return _direct_conversion (value )
145+ # Try using the "new" way to convert using from_dlpack directly
146+ return jax_from_dlpack (
147+ value , device = torch_to_jax_device (value .device ), copy = None
148+ )
149+ except AssertionError as err :
150+ if not err .args [0 ].startswith ("Unexpected XLA layout override" ):
151+ raise
152+ # Some "AssertionError: Unexpected XLA layout override"
153+ # Try using the "old" way to convert using from_dlpack of a dlpack tensor.
154+ try :
155+ dlpack = torch_to_dlpack (value )
156+ return jax_from_dlpack (dlpack , copy = False )
157+ except jaxlib .xla_extension .XlaRuntimeError as err :
158+ log_once (
159+ logger ,
160+ message = (
161+ f"Unable to view tensor of shape { tuple (value .shape )} as a jax.Array in-place:\n "
162+ f"'{ err } '\n "
163+ f"Tensors of this shape will be flattened and unflattened (which may or "
164+ f"may not involve making a copy of the tensor's data)."
165+ ),
166+ level = logging .WARNING ,
167+ )
140168 except jaxlib .xla_extension .XlaRuntimeError as err :
141169 log_once (
142170 logger ,
0 commit comments