From d40328105e65901051d18321cca648e9cdfa2760 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 22 Oct 2024 06:01:15 +0000 Subject: [PATCH 1/2] restore test_apply_paddings_check runtime_checks test The main idea is that we need to call `jax.effects_barrier()`, because the error may be raised in an XLA computation that is asynchronous with the main Python thread and therefore we need to block. (There may have been a recent change in behavior, where JAX runs more computations asynchronously on the CPU backend.) We could put that call to `jax.effects_barrier()` in the test code (and corresponding user code), or we could build it into the `runtime_checks` context manager. Currently this commit does the latter. I also tweaked the `runtime_checks` logic to use a `try/finally` pattern to restore the state when the context is exited, even when it's exited via exception. We may want to do the same to context managers like `numeric_checks`. While the test now passes, there is a gross warning printed about "Exception ignored in atexit callback". That may be a JAX internal bug, or it may be some quirk of CPython 3.10; I haven't investigated further. Let me know if that seems like a problem. 
--- axlearn/common/transducer_test.py | 26 +++++++++++++------------- axlearn/common/utils.py | 7 +++++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/axlearn/common/transducer_test.py b/axlearn/common/transducer_test.py index d56a7c288..74c6181e3 100644 --- a/axlearn/common/transducer_test.py +++ b/axlearn/common/transducer_test.py @@ -4,6 +4,7 @@ # pylint: disable=duplicate-code,invalid-name import jax +import jaxlib import numpy as np import tensorflow as tf from absl import logging @@ -28,7 +29,7 @@ log_prob_suffix_alignments, log_probs_from_blank_and_tokens, ) -from axlearn.common.utils import NestedTensor, Tensor +from axlearn.common.utils import NestedTensor, Tensor, runtime_checks def numpy_log_prob_prefix_alignments( @@ -299,18 +300,17 @@ def test_apply_paddings_check(self): ) log_prob_blank, log_prob_y = jnp.log(prob_blank), jnp.log(prob_y) - # TODO(matthew_e_hopkins): test fails as of jax 0.4.33 through 0.4.35, revisit - # with runtime_checks(): - # with self.assertRaisesRegex( - # jaxlib.xla_extension.XlaRuntimeError, - # "lm_paddings cannot be all 1s.", - # ): - # jax.jit(jax.vmap(apply_paddings))( - # log_prob_blank=log_prob_blank, - # log_prob_y=log_prob_y, - # am_paddings=am_paddings, - # lm_paddings=lm_paddings, - # ) + with self.assertRaisesRegex( + jaxlib.xla_extension.XlaRuntimeError, + "lm_paddings cannot be all 1s.", + ): + with runtime_checks(): + jax.jit(jax.vmap(apply_paddings))( + log_prob_blank=log_prob_blank, + log_prob_y=log_prob_y, + am_paddings=am_paddings, + lm_paddings=lm_paddings, + ) check_apply_paddings = checkify.checkify(apply_paddings, errors=checkify.user_checks) err, _ = jax.jit(jax.vmap(check_apply_paddings))( log_prob_blank=log_prob_blank, diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 9677a6954..d3b864f19 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -141,8 +141,11 @@ def switch(value): jax.config.update("jax_experimental_unsafe_xla_runtime_errors", value) 
switch(enabled) - yield - switch(old_state) + try: + yield + jax.effects_barrier() + finally: + switch(old_state) @contextlib.contextmanager From 36a938ce2648b0934071168d7c61d82f46c7cfdd Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 22 Oct 2024 17:10:16 +0000 Subject: [PATCH 2/2] use JaxRuntimeError name when available --- axlearn/common/transducer_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/axlearn/common/transducer_test.py b/axlearn/common/transducer_test.py index 74c6181e3..987722e85 100644 --- a/axlearn/common/transducer_test.py +++ b/axlearn/common/transducer_test.py @@ -300,10 +300,12 @@ def test_apply_paddings_check(self): ) log_prob_blank, log_prob_y = jnp.log(prob_blank), jnp.log(prob_y) - with self.assertRaisesRegex( - jaxlib.xla_extension.XlaRuntimeError, - "lm_paddings cannot be all 1s.", - ): + # TODO(mattjj): replace with jax.errors.JaxRuntimeError when minimum jax + # version is jax>=0.4.35 + cls = getattr(jax.errors, 'JaxRuntimeError', + jaxlib.xla_extension.XlaRuntimeError) + + with self.assertRaisesRegex(cls, "lm_paddings cannot be all 1s."): with runtime_checks(): jax.jit(jax.vmap(apply_paddings))( log_prob_blank=log_prob_blank,