Skip to content

feat: enable jax backend for virtual arrays #3451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 30, 2025

Conversation

ikrommyd
Copy link
Collaborator

Title explains it. This is a PR we'll need to revert if the jax backend gets dropped one day.

@ikrommyd ikrommyd marked this pull request as draft April 11, 2025 04:11
@ikrommyd
Copy link
Collaborator Author

ikrommyd commented Apr 11, 2025

I've ran the full suite with the jax tests enabled in the https://github.com/scikit-hep/awkward/tree/ikrommyd/test-virtual-arrays branch and I'm only getting these:

FAILED tests/test_1490_jax_reducers_combinations.py::test_bool_raises[all-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1762_jax_behavior_support.py::test_jvp_nested_list - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_numpyarray_grad_3 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_regular_array_3 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_regular_array_4 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_2638_mean_and_count_grads.py::test - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_recordarray_7 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[sum-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[prod0-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[min-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[max-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[prod1-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_int_output_reducer[argmin-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_int_output_reducer[argmax-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_sort[sort-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_bool_raises[any-0] - ValueError: not enough values to unpack (expected 3, got 0)

I've checked locally and this is an artifact of the wrapping inside the Index and the NumpyArray that we do. Didn't try to understand why that happens yet due to the wrapping but I've checked the such tests pass locally with TRUE virtual arrays. For example:

import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib

from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length

numpy = Numpy.instance()
jax = Jax.instance()

def generator_offsets():
    print("CALLING OFFSETS GENERATOR!")
    return np.array([0, 3, 3, 5], dtype=np.int64)

def generator_x():
    print("CALLING X GENERATOR!")
    return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)

layout = ak.contents.ListOffsetArray(
    ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
    ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
array = ak.materialize(layout)
val_mean, grad_mean = jaxlib.value_and_grad(ak.mean, argnums=0)(array)
_, grad_sum = jaxlib.value_and_grad(ak.sum, argnums=0)(array)
val_count, grad_count = jaxlib.value_and_grad(ak.count, argnums=0)(array)
assert val_mean == 3
assert ak.all(
    grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax")
)

# mean is treated as scaled sum
assert ak.all(grad_mean == grad_sum / val_count)

assert val_count == 5
assert ak.all(
    grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax")
)

In this test, the ak.mean is the issue in the testing branch. The following works totally fine on this PR's branch but raises

ValueError: not enough values to unpack (expected 3, got 0)

This error occurred while calling

    ak.mean(
        <Array [[1.0, 2.0, 3.0], [], [4.0, 5.0]] type='3 * var * float64'>
    )

in the testing branch

import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib

from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length

numpy = Numpy.instance()
jax = Jax.instance()

def generator_offsets():
    print("CALLING OFFSETS GENERATOR!")
    return np.array([0, 3, 3, 5], dtype=np.int64)

def generator_x():
    print("CALLING X GENERATOR!")
    return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)

layout = ak.contents.ListOffsetArray(
    ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
    ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
ak.mean(layout)

@ikrommyd
Copy link
Collaborator Author

@ianna @pfackeldey let me know what you think of the way I'm adding this. I'll probably just add some less extensive testing for jax-virtual arrays.

@ikrommyd
Copy link
Collaborator Author

Needs #3457 and #3464

@ikrommyd ikrommyd marked this pull request as ready for review April 28, 2025 14:32
@ikrommyd ikrommyd requested review from pfackeldey and ianna April 28, 2025 14:32
@ikrommyd ikrommyd linked an issue Apr 28, 2025 that may be closed by this pull request
Copy link
Collaborator

@pfackeldey pfackeldey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

Copy link
Collaborator

@ianna ianna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ikrommyd - Looks great! Thanks! Please, merge it if you are done with it. Thanks.

@ianna ianna merged commit 6df5fd9 into main Apr 30, 2025
42 checks passed
@ianna ianna deleted the ikrommyd/enable-jax-backend-for-virtualarray branch April 30, 2025 12:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Should VirtualArrays allow JAX as nplike?
3 participants