Description
The following assertion in check_numerics doesn't work when x is a traced value:
assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
Traceback (most recent call last):
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 782, in __bool__
    return self.aval._bool(self)
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 1538, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fn at /Users/dongseong/Workspaces/axlearn/axlearn/common/base_layer.py:329 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument kwargs['inputs'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
So check_numerics doesn't work inside jit, pmap, and scan.
def check_numerics(x: Tensor, msg_fmt: str = "", **msg_kwargs):
    """Checks that all elements in `x` are finite."""
    global _enable_numeric_checks  # pylint: disable=global-statement,global-variable-not-assigned
    if _enable_numeric_checks:
        assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x
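A minimal reproduction outside axlearn may help confirm the failure mode. The sketch below assumes only jax is installed; check_finite is a hypothetical stand-in that copies the assert pattern from check_numerics without the enable flag or message formatting.

```python
import jax
import jax.numpy as jnp


def check_finite(x):
    # Hypothetical stand-in for check_numerics: bool() needs a concrete value.
    assert bool(jnp.isfinite(x).all()), f"Check numerics: {x}"
    return x


x = jnp.array([1.0, jnp.inf])

# Eager call: x is concrete, so the assert fires as intended.
try:
    check_finite(x)
    eager_result = "no error"
except AssertionError:
    eager_result = "AssertionError"

# Under jit: x is a tracer, so bool() raises before the assert can be evaluated.
try:
    jax.jit(check_finite)(x)
    jit_result = "no error"
except jax.errors.TracerBoolConversionError:
    jit_result = "TracerBoolConversionError"

print(eager_result, jit_result)
```

The eager call reports the non-finite value, while the jitted call fails with the same TracerBoolConversionError shown in the traceback above.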
JAX provides checkify, but it requires the top-level function to be wrapped with checkify.checkify(main). That is not trivial to do in axlearn.
https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html
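For reference, here is a rough sketch of what a checkify-based check could look like. It is not the axlearn implementation: check_numerics_checkify and layer are hypothetical names used only for illustration, and the sketch assumes the caller is able to wrap the outermost function.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_numerics_checkify(x, msg="non-finite value found"):
    # checkify.check accepts a traced boolean, unlike `assert bool(...)`.
    checkify.check(jnp.isfinite(x).all(), msg)
    return x


def layer(x):
    # Hypothetical stand-in for a layer's forward pass.
    return check_numerics_checkify(jnp.log(x)) * 2.0


# The outermost function has to be wrapped with checkify.checkify.
checked_layer = jax.jit(checkify.checkify(layer))

err, out = checked_layer(jnp.array([1.0, 2.0]))
ok_message = err.get()  # None when no check failed

err, out = checked_layer(jnp.array([1.0, 0.0]))  # log(0) is -inf
bad_message = err.get()  # error string; err.throw() would raise instead
```

The wrapped function returns an (error, output) pair, so every call site that reaches a check has to thread the error value back out. This is the part that makes adoption in a large codebase awkward.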