Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I believe this PR resolves issue #2489.
The solution is to perform code branching using the conditional control flow jax.lax.cond instead of jp.where.
From what I understand, jax.lax.cond only evaluates the branch that is used, whereas jp.where evaluates both the used and the unused branches. Consequently, if the unused branch has an NaN, this will cause a problem for jp.where (even though the branch is not executed) but not for jax.lax.cond.
@Balint-H noted that the NaNs encountered in issue #2489 go away if tausmooth is set to a nonzero value. I believe this is because the computation of tau_smooth involves dividing by smoothing_width, which gives NaN if smoothing_width = 0 (again, even if the smooth switching branch is not being executed).
If I perform branching with jax.lax.cond (as per this PR), the NaNs go away, even with tausmooth = 0.
To get the MWE provided in issue #2489 to work, I also had to replace jp.where with jax.lax.cond here, again because it involves dividing by 0. I have included this change in my PR. More generally, there are quite a lot of jp.where calls in the mjx codebase, and I wonder if these should also be changed to jax.lax.cond?