-
Notifications
You must be signed in to change notification settings - Fork 90
feat: enable jax backend for virtual arrays #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I've ran the full suite with the jax tests enabled in the https://github.com/scikit-hep/awkward/tree/ikrommyd/test-virtual-arrays branch and I'm only getting these:
I've checked locally and this is an artifact of the wrapping inside the import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length
numpy = Numpy.instance()
jax = Jax.instance()
def generator_offsets():
print("CALLING OFFSETS GENERATOR!")
return np.array([0, 3, 3, 5], dtype=np.int64)
def generator_x():
print("CALLING X GENERATOR!")
return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)
layout = ak.contents.ListOffsetArray(
ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
array = ak.materialize(layout)
val_mean, grad_mean = jaxlib.value_and_grad(ak.mean, argnums=0)(array)
_, grad_sum = jaxlib.value_and_grad(ak.sum, argnums=0)(array)
val_count, grad_count = jaxlib.value_and_grad(ak.count, argnums=0)(array)
assert val_mean == 3
assert ak.all(
grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax")
)
# mean is treated as scaled sum
assert ak.all(grad_mean == grad_sum / val_count)
assert val_count == 5
assert ak.all(
grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax")
) In this test, the
in the testing branch import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length
numpy = Numpy.instance()
jax = Jax.instance()
def generator_offsets():
print("CALLING OFFSETS GENERATOR!")
return np.array([0, 3, 3, 5], dtype=np.int64)
def generator_x():
print("CALLING X GENERATOR!")
return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)
layout = ak.contents.ListOffsetArray(
ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
ak.mean(layout) |
@ianna @pfackeldey let me know what you think of the way I'm adding this. I'll probably just add some less extensive testing for jax-virtual arrays. |
Title explains it. This is a PR we'll need to revert if the jax backend gets dropped one day.