Skip to content

Fix jnp.vectorize crash when None is at index 0#38117

Open
Kymi808 wants to merge 1 commit into
jax-ml:mainfrom
Kymi808:fix/vectorize-none-at-index-zero
Open

Fix jnp.vectorize crash when None is at index 0#38117
Kymi808 wants to merge 1 commit into
jax-ml:mainfrom
Kymi808:fix/vectorize-none-at-index-zero

Conversation

@Kymi808
Copy link
Copy Markdown

@Kymi808 Kymi808 commented Jun 2, 2026

Summary

jnp.vectorize crashes with ValueError: None is not a valid value for jnp.array when a positional None argument is at index 0, but works for any other position.

Root cause

jax/_src/numpy/vectorize.py collects the positions of None arguments into a set and guards the None-handling branch with if any(none_args):. When the only None is at index 0, none_args == {0} and any({0}) is False (the set has one element and that element is 0). The branch that strips None args out is skipped, the next line runs jnp.asarray(None), and that raises.

>>> import jax.numpy as jnp
>>> f = jnp.vectorize(lambda y, x: x if y is None else x + y)
>>> f(None, jnp.arange(10))
ValueError: None is not a valid value for jnp.array

With the same None at any other index, vectorize works as designed (e.g. f(jnp.arange(10), None)).

None support was added in #18441; the regression test (test_none_arg) only exercised None at index 1, which is why this slipped through.

Fix

Replace if any(none_args): with if none_args:. The set's emptiness is what we actually want to test — same intent, different reduction. One-character change.

Test

Extended test_none_arg in tests/lax_numpy_vectorize_test.py to also exercise None at index 0. Verified the fix locally by patching the installed jax and re-running:

  • g(None, x) returns x (previously crashed)
  • f(x, None), f(x, y) continue to work
  • test_none_arg_bad_signature continues to raise the expected ValueError for invalid signatures, including with None at index 0

Duplicate check

Searched open and closed PRs and issues for vectorize none, any(none_args), and vectorize None at index. No prior PR or open issue covers this case; the only related PR (#18441) is the merged one that introduced the bug.

Disclosure

AI-assisted: a code-search agent flagged the suspicious any(set) pattern; the fix and tests were reviewed line-by-line before submission. I will sign the CLA after opening this PR if not already covered.

`vectorize` collected `None` argument positions into a set and
guarded the None-handling branch with `if any(none_args):`. When the
only None is at index 0, `any({0})` evaluates as `any` over a single
falsy element and returns False, so the branch is skipped and the
subsequent `jnp.asarray(None)` raises
`ValueError: None is not a valid value for jnp.array`.

Use a truthiness check on the set itself, and extend the existing
`test_none_arg` regression to cover the index-0 case.
@google-cla
Copy link
Copy Markdown

google-cla Bot commented Jun 2, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes a bug in jax.numpy.vectorize where passing None as the first argument (index 0) was not handled correctly because any({0}) evaluated to False. The condition has been corrected to if none_args: to properly check if the set is non-empty, and a regression test has been added to prevent future occurrences. There are no review comments, so I have no feedback to provide.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@Kymi808
Copy link
Copy Markdown
Author

Kymi808 commented Jun 2, 2026

@googlebot I signed it!

1 similar comment
@Kymi808
Copy link
Copy Markdown
Author

Kymi808 commented Jun 2, 2026

@googlebot I signed it!

Copy link
Copy Markdown
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! The fix here makes sense – just one comment on the test code.

Once you make the changes, could you please squash all changes into a single commit? Thanks!

y = jnp.arange(10, 20)
self.assertAllClose(f(x, y), x + y)

# Regression test for None at index 0: ``any(none_args)`` evaluated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put this in a separate test.

Also, let's delete this long comment and instead write this:

# Regression test for https://github.com/jax-ml/jax/pull/38117

because the PR/bug has the relevant details.

@jakevdp jakevdp self-assigned this Jun 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants