-
Notifications
You must be signed in to change notification settings - Fork 94
feat: enable jax backend for virtual arrays #3451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
I've ran the full suite with the jax tests enabled in the https://github.com/scikit-hep/awkward/tree/ikrommyd/test-virtual-arrays branch and I'm only getting these:
I've checked locally and this is an artifact of the wrapping inside the import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length
numpy = Numpy.instance()
jax = Jax.instance()
def generator_offsets():
print("CALLING OFFSETS GENERATOR!")
return np.array([0, 3, 3, 5], dtype=np.int64)
def generator_x():
print("CALLING X GENERATOR!")
return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)
layout = ak.contents.ListOffsetArray(
ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
array = ak.materialize(layout)
val_mean, grad_mean = jaxlib.value_and_grad(ak.mean, argnums=0)(array)
_, grad_sum = jaxlib.value_and_grad(ak.sum, argnums=0)(array)
val_count, grad_count = jaxlib.value_and_grad(ak.count, argnums=0)(array)
assert val_mean == 3
assert ak.all(
grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax")
)
# mean is treated as scaled sum
assert ak.all(grad_mean == grad_sum / val_count)
assert val_count == 5
assert ak.all(
grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax")
) In this test, the
in the testing branch import awkward as ak
ak.jax.register_and_check()
import numpy as np
import jax.numpy as jnp
import jax as jaxlib
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.jax import Jax
from awkward._nplikes.shape import unknown_length
numpy = Numpy.instance()
jax = Jax.instance()
def generator_offsets():
print("CALLING OFFSETS GENERATOR!")
return np.array([0, 3, 3, 5], dtype=np.int64)
def generator_x():
print("CALLING X GENERATOR!")
return jnp.array([1., 2., 3., 4., 5.], dtype=np.float64)
layout = ak.contents.ListOffsetArray(
ak.index.Index(VirtualArray(numpy, (4,), np.int64, generator_offsets)),
ak.contents.NumpyArray(VirtualArray(jax, (5,), np.float64, generator_x)),
)
ak.mean(layout) |
@ianna @pfackeldey let me know what you think of the way I'm adding this. I'll probably just add some less extensive testing for jax-virtual arrays. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ikrommyd - Looks great! Thanks! Please, merge it if you are done with it. Thanks.
Title explains it. This is a PR we'll need to revert if the jax backend gets dropped one day.