Skip to content

Fix spurious KeyReuseError under vmap over a zero-sized key array#38174

Open
Lawson-Darrow wants to merge 1 commit into
jax-ml:mainfrom
Lawson-Darrow:fix/key-reuse-vmap-zero-size-37859
Open

Fix spurious KeyReuseError under vmap over a zero-sized key array#38174
Lawson-Darrow wants to merge 1 commit into
jax-ml:mainfrom
Lawson-Darrow:fix/key-reuse-vmap-zero-size-37859

Conversation

@Lawson-Darrow
Copy link
Copy Markdown

Fixes #37859.

When jax_debug_key_reuse is enabled, vmap-ing a function over a zero-sized PRNG key array raised a spurious KeyReuseError, even when the function did no key reuse (e.g. it split its key and consumed each derived sub-key exactly once).

In _SourceSinkBase.__init__, an all-True mask is collapsed to scalar True. But np.all() is vacuously True for an empty array, so a mask coming from a zero-sized vmap axis (which affects no keys) was wrongly collapsed to True, i.e. full consumption, which then tripped the reuse check.

Checking not np.any(mask) before np.all(mask) makes an empty mask collapse to False (a no-op) while leaving every non-empty mask unchanged (for any non-empty mask exactly one of the two branches was already taken). Adds a regression test.

With jax_debug_key_reuse enabled, vmap over a zero-sized PRNG key array
raised a spurious KeyReuseError even when the function did no key reuse.
In _SourceSinkBase.__init__ an all-True mask collapses to scalar True,
but np.all() is vacuously True for an empty array, so a mask from a
zero-sized vmap axis (which affects no keys) wrongly became True, i.e.
full consumption. Check `not np.any(mask)` before `np.all(mask)` so an
empty mask collapses to False, a no-op.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes an issue where vmap over a zero-sized key axis raised a spurious KeyReuseError. It reorders the mask checks in jax/experimental/key_reuse/_core.py so that empty masks (where np.all is vacuously True) are correctly handled first as a no-op. A regression test has also been added to verify this behavior. There are no review comments, so I have no additional feedback to provide.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax_debug_key_reuse raises spurious KeyReuseError under vmap over a 0-sized key array

1 participant