
check_numerics doesn't work inside repeat.py #815

Open
@ds-hwang

Description


The assertion in `check_numerics`, `assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"`, fails when `x` is a traced value:

```
assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
Traceback (most recent call last):
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 782, in __bool__
    return self.aval._bool(self)
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 1538, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fn at /Users/dongseong/Workspaces/axlearn/axlearn/common/base_layer.py:329 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument kwargs['inputs'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
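A minimal reproduction outside axlearn, assuming only `jax` is installed. `naive_check_numerics` and `step` are hypothetical names; the helper uses the same `assert bool(...)` pattern as `check_numerics`, succeeds when called eagerly, and fails once the caller is wrapped in `jax.jit`:

```python
import jax
import jax.numpy as jnp


def naive_check_numerics(x):
    # Hypothetical stand-in: same `assert bool(...)` pattern as check_numerics.
    assert bool(jnp.isfinite(x).all()), f"Check numerics: {x}"
    return x


def step(x):
    return naive_check_numerics(x * 2.0)


# Eager call: `x` is a concrete array, so bool() succeeds.
eager_out = step(jnp.ones(3))

# Under jit, `x` is a tracer with no concrete value, so bool() raises.
try:
    jax.jit(step)(jnp.ones(3))
    jit_failed = False
except jax.errors.TracerBoolConversionError:
    jit_failed = True

print("eager:", eager_out, "jit failed:", jit_failed)
```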

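So `check_numerics` doesn't work inside `jit`, `pmap`, or `scan` (or, as in the traceback above, `checkpoint`): under these transformations `x` is an abstract tracer rather than a concrete array, so `bool()` cannot be evaluated and JAX raises `TracerBoolConversionError`.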
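The failure does not need an explicit `jit`: `jax.lax.scan` traces its body once with abstract values, so the carry is always a tracer. Assuming `repeat.py` runs its sublayer through `scan`, that alone is enough to trigger the error. A small sketch with a hypothetical `naive_check_numerics`:

```python
import jax
import jax.numpy as jnp


def naive_check_numerics(x):
    # Hypothetical stand-in using the same Python-level assert.
    assert bool(jnp.isfinite(x).all()), "Check numerics failed"
    return x


def body(carry, _):
    # scan traces this body with abstract values, so `carry` is a tracer here.
    carry = naive_check_numerics(carry + 1.0)
    return carry, None


# No jax.jit anywhere, yet the body is still traced by scan.
try:
    jax.lax.scan(body, jnp.zeros(2), xs=None, length=3)
    scan_failed = False
except jax.errors.TracerBoolConversionError:
    scan_failed = True

print("scan failed:", scan_failed)
```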

```python
def check_numerics(x: Tensor, msg_fmt: str = "", **msg_kwargs):
    """Checks that all elements in `x` are finite."""
    global _enable_numeric_checks  # pylint: disable=global-statement,global-variable-not-assigned
    if _enable_numeric_checks:
        assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x
```

JAX provides `checkify`, but it requires the top-level function to be wrapped with `checkify.checkify(main)`, and that is not trivial to do in axlearn.
https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html
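For comparison, a sketch of the `checkify` route from the guide linked above. `check_numerics_checkify` and `main` are hypothetical; the point is that `checkify.check` works inside `scan` and `jit`, but only if the top-level function is wrapped with `checkify.checkify` and the caller inspects the returned error, which is the plumbing that makes it awkward to adopt across axlearn:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_numerics_checkify(x, msg: str = ""):
    # checkify.check records a functionalized error instead of calling bool().
    checkify.check(jnp.isfinite(x).all(), f"Check numerics {msg}: non-finite values")
    return x


def main(x):
    def body(carry, _):
        carry = check_numerics_checkify(jnp.log(carry), msg="scan body")
        return carry, None

    out, _ = jax.lax.scan(body, x, xs=None, length=2)
    return out


# The top-level function must be wrapped with checkify.checkify; jit goes outside.
checked_main = jax.jit(checkify.checkify(main))

err_ok, _ = checked_main(jnp.array([2.0, 3.0]))   # stays finite over both steps
err_bad, _ = checked_main(jnp.array([0.5, -1.0]))  # log(-1) produces NaN
print(err_ok.get())   # no error recorded
print(err_bad.get())  # message from the failing check
# err_bad.throw() would raise the recorded error as an exception instead.
```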
