Description
Hello,
I'm having trouble using my Warp kernels together with jax.vmap. I set it up as explained in the documentation, but the kernel is only ever launched across a single dimension, so only part of the batch gets processed.
Minimal reproducer:
import warp as wp
import jax
import jax.numpy as jnp
from warp.jax_experimental.ffi import jax_kernel

wp.init()

# 1. The Warp Kernel
@wp.kernel
def scalar_add_kernel(
    data_io: wp.array(dtype=wp.float32, ndim=1),
    val: float
):
    wp.printf("Thread %d processing data_io with shape %d\n", wp.tid(), data_io.shape[0])
    for i in range(data_io.shape[0]):
        data_io[i] = data_io[i] + val

# 2. The JAX Interface
def get_jax_kernel_fn():
    return jax_kernel(
        scalar_add_kernel,
        launch_dims=(1,),
        in_out_argnames=['data_io'],
        vmap_method='broadcast_all'
    )

# 3. Usage with vmap
def batch_process(batch_array, value_to_add, axis_to_vmap):
    kernel_fn = get_jax_kernel_fn()
    vmapped_fn = jax.vmap(kernel_fn, in_axes=(axis_to_vmap, None))
    return vmapped_fn(batch_array, value_to_add)

# 4. Run Test
data = jnp.zeros((3, 5), dtype=jnp.float32)
print("Input (3x5):\n", data)
result = jax.jit(batch_process, static_argnums=(1, 2))(data, 10.0, 0)
print("\nOutput (should be all 10s):\n", result)
result_ = jax.jit(batch_process, static_argnums=(1, 2))(data, 10.0, 1)
print("\nOutput with different striding (should be all 10s):\n", result_)
Output:
Input (3x5):
[[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]
Module main 057841c load on device 'cuda:0' took 433.62 ms (compiled)
Thread 0 processing data_io with shape 3
Output (should be all 10s):
[Array([[10., 10., 10., 0., 0.],
[ 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0.]], dtype=float32)]
Output with different striding (should be all 10s):
Thread 0 processing data_io with shape 5
[Array([[10., 10., 10.],
[10., 10., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]], dtype=float32)]
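For reference, the result I expect can be sketched in pure JAX with no Warp involved. The helpers `reference_add` and `batch_reference` are hypothetical names made up for this illustration; they are not part of Warp or of the reproducer. The sketch only assumes that the kernel should behave like an elementwise scalar add on each 1-D slice:

```python
import jax
import jax.numpy as jnp


def reference_add(row, val):
    # Pure-JAX stand-in for the Warp kernel: add a scalar to one 1-D slice.
    return row + val


def batch_reference(batch_array, value_to_add, axis_to_vmap):
    # Same in_axes as in the reproducer; out_axes keeps its default of 0,
    # so the mapped axis always ends up first in the result.
    mapped = jax.vmap(reference_add, in_axes=(axis_to_vmap, None))
    return mapped(batch_array, value_to_add)


data = jnp.zeros((3, 5), dtype=jnp.float32)
expected_axis0 = batch_reference(data, 10.0, 0)  # 3 slices of length 5
expected_axis1 = batch_reference(data, 10.0, 1)  # 5 slices of length 3
print(expected_axis0.shape, expected_axis1.shape)
```

With this reference every element comes out as 10, whereas the Warp version above only updates the first few elements of the flattened buffer.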
Could you please point out the error in my code?