Skip to content

feat: enable jax backend for virtual arrays #3451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

ikrommyd
Copy link
Collaborator

Title explains it. This is a PR we'll need to revert if the jax backend gets dropped one day.

@ikrommyd ikrommyd marked this pull request as draft April 11, 2025 04:11
@ikrommyd
Copy link
Collaborator Author

ikrommyd commented Apr 11, 2025

I've ran the full suite with the jax tests enabled in the https://github.com/scikit-hep/awkward/tree/ikrommyd/test-virtual-arrays branch and I'm only getting these:

FAILED tests/test_1490_jax_reducers_combinations.py::test_bool_raises[all-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1762_jax_behavior_support.py::test_jvp_nested_list - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_numpyarray_grad_3 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_regular_array_3 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_regular_array_4 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_2638_mean_and_count_grads.py::test - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1447_jax_autodiff_slices_ufuncs.py::test_recordarray_7 - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[sum-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[prod0-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[min-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[max-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[mean-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[prod1-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[ptp-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-1] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_reducer[std-None] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_int_output_reducer[argmin-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_int_output_reducer[argmax-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_sort[sort-0] - ValueError: not enough values to unpack (expected 3, got 0)
FAILED tests/test_1490_jax_reducers_combinations.py::test_bool_raises[any-0] - ValueError: not enough values to unpack (expected 3, got 0)

I've checked locally and this is an artifact of the wrapping inside the Index and the NumpyArray that we do. Didn't try to understand why that happens yet due to the wrapping but I've checked the such tests pass locally with TRUE virtual arrays. For example:

import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib

from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length

numpy = Numpy.instance()
jax = Jax.instance()

def generator_offsets():
    print("CALLING OFFSETS GENERATOR!")
    return np.array([0, 3, 3, 5], dtype=np.int64)

def generator_x():
    print("CALLING X GENERATOR!")
    return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)

layout = ak.contents.ListOffsetArray(
    ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
    ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
array = ak.materialize(layout)
val_mean, grad_mean = jaxlib.value_and_grad(ak.mean, argnums=0)(array)
_, grad_sum = jaxlib.value_and_grad(ak.sum, argnums=0)(array)
val_count, grad_count = jaxlib.value_and_grad(ak.count, argnums=0)(array)
assert val_mean == 3
assert ak.all(
    grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax")
)

# mean is treated as scaled sum
assert ak.all(grad_mean == grad_sum / val_count)

assert val_count == 5
assert ak.all(
    grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax")
)

In this test, the ak.mean is the issue in the testing branch. The following works totally fine on this PR's branch but raises

ValueError: not enough values to unpack (expected 3, got 0)

This error occurred while calling

    ak.mean(
        <Array [[1.0, 2.0, 3.0], [], [4.0, 5.0]] type='3 * var * float64'>
    )

in the testing branch

import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib

from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length

numpy = Numpy.instance()
jax = Jax.instance()

def generator_offsets():
    print("CALLING OFFSETS GENERATOR!")
    return np.array([0, 3, 3, 5], dtype=np.int64)

def generator_x():
    print("CALLING X GENERATOR!")
    return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)

layout = ak.contents.ListOffsetArray(
    ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
    ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
ak.mean(layout)

@ikrommyd
Copy link
Collaborator Author

@ianna @pfackeldey let me know what you think of the way I'm adding this. I'll probably just add some less extensive testing for jax-virtual arrays.

@ikrommyd
Copy link
Collaborator Author

Needs #3457 and #3464

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant