Skip to content

Commit cb3536b

Browse files
committed
Try to work arround AssertionError in XLA
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1 parent d548220 commit cb3536b

2 files changed

Lines changed: 32 additions & 3 deletions

File tree

torch_jax_interop/to_jax.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,14 @@ def _direct_conversion(v: torch.Tensor) -> jax.Array:
9393
return jax_from_dlpack(v, copy=False)
9494

9595

96-
def _to_from_dlpack(v: torch.Tensor, ignore_deprecation_warning: bool = True) -> jax.Array:
97-
with warnings.catch_warnings() if ignore_deprecation_warning else contextlib.nullcontext():
96+
def _to_from_dlpack(
97+
v: torch.Tensor, ignore_deprecation_warning: bool = True
98+
) -> jax.Array:
99+
with (
100+
warnings.catch_warnings()
101+
if ignore_deprecation_warning
102+
else contextlib.nullcontext()
103+
):
98104
# Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated
99105
# conversion method for now.
100106
# todo: Should we let it though though?
@@ -136,7 +142,29 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
136142
return _direct_conversion(value.flatten()).reshape(value.shape)
137143

138144
try:
139-
return _direct_conversion(value)
145+
# Try using the "new" way to convert using from_dlpack directly
146+
return jax_from_dlpack(
147+
value, device=torch_to_jax_device(value.device), copy=None
148+
)
149+
except AssertionError as err:
150+
if not err.args[0].startswith("Unexpected XLA layout override"):
151+
raise
152+
# Some "AssertionError: Unexpected XLA layout override"
153+
# Try using the "old" way to convert using from_dlpack of a dlpack tensor.
154+
try:
155+
dlpack = torch_to_dlpack(value)
156+
return jax_from_dlpack(dlpack, copy=False)
157+
except jaxlib.xla_extension.XlaRuntimeError as err:
158+
log_once(
159+
logger,
160+
message=(
161+
f"Unable to view tensor of shape {tuple(value.shape)} as a jax.Array in-place:\n"
162+
f"'{err}'\n"
163+
f"Tensors of this shape will be flattened and unflattened (which may or "
164+
f"may not involve making a copy of the tensor's data)."
165+
),
166+
level=logging.WARNING,
167+
)
140168
except jaxlib.xla_extension.XlaRuntimeError as err:
141169
log_once(
142170
logger,

torch_jax_interop/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def log_once(logger: logging.Logger, message: str, level: int):
1111
logger.log(level=level, msg=message, stacklevel=2)
1212

1313

14+
# NOTE: Done like this to preserve the original function signature.
1415
log_once = functools.cache(log_once)
1516

1617

0 commit comments

Comments
 (0)