Fix jnp.vectorize crash when None is at index 0#38117
Conversation
`vectorize` collected `None` argument positions into a set and
guarded the None-handling branch with `if any(none_args):`. When the
only None is at index 0, `any({0})` evaluates as `any` over a single
falsy element and returns False, so the branch is skipped and the
subsequent `jnp.asarray(None)` raises
`ValueError: None is not a valid value for jnp.array`.
Use a truthiness check on the set itself, and extend the existing
`test_none_arg` regression to cover the index-0 case.
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Code Review
This pull request fixes a bug in jax.numpy.vectorize where passing None as the first argument (index 0) was not handled correctly because any({0}) evaluated to False. The condition has been corrected to if none_args: to properly check if the set is non-empty, and a regression test has been added to prevent future occurrences. There are no review comments, so I have no feedback to provide.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
@googlebot I signed it! |
1 similar comment
|
@googlebot I signed it! |
jakevdp
left a comment
There was a problem hiding this comment.
Thanks! The fix here makes sense – just one comment on the test code.
Once you make the changes, could you please squash all changes into a single commit? Thanks!
| y = jnp.arange(10, 20) | ||
| self.assertAllClose(f(x, y), x + y) | ||
|
|
||
| # Regression test for None at index 0: ``any(none_args)`` evaluated |
There was a problem hiding this comment.
Please put this in a separate test.
Also, let's delete this long comment and instead write this:
# Regression test for https://github.com/jax-ml/jax/pull/38117
because the PR/bug has the relevant details.
Summary
jnp.vectorizecrashes withValueError: None is not a valid value for jnp.arraywhen a positionalNoneargument is at index 0, but works for any other position.Root cause
jax/_src/numpy/vectorize.pycollects the positions ofNonearguments into a set and guards the None-handling branch withif any(none_args):. When the onlyNoneis at index 0,none_args == {0}andany({0})isFalse(the set has one element and that element is0). The branch that stripsNoneargs out is skipped, the next line runsjnp.asarray(None), and that raises.With the same
Noneat any other index,vectorizeworks as designed (e.g.f(jnp.arange(10), None)).Nonesupport was added in #18441; the regression test (test_none_arg) only exercisedNoneat index 1, which is why this slipped through.Fix
Replace
if any(none_args):withif none_args:. The set's emptiness is what we actually want to test — same intent, different reduction. One-character change.Test
Extended
test_none_argintests/lax_numpy_vectorize_test.pyto also exerciseNoneat index 0. Verified the fix locally by patching the installed jax and re-running:g(None, x)returnsx(previously crashed)f(x, None),f(x, y)continue to worktest_none_arg_bad_signaturecontinues to raise the expectedValueErrorfor invalid signatures, including withNoneat index 0Duplicate check
Searched open and closed PRs and issues for
vectorize none,any(none_args), andvectorize None at index. No prior PR or open issue covers this case; the only related PR (#18441) is the merged one that introduced the bug.Disclosure
AI-assisted: a code-search agent flagged the suspicious
any(set)pattern; the fix and tests were reviewed line-by-line before submission. I will sign the CLA after opening this PR if not already covered.