
[QUESTION] How to interface with jax vmap correctly? #1161

@gd193


Hello,

I'm having trouble using my Warp kernels together with `jax.vmap`. I set things up as described in the documentation, but the kernel is only ever launched over a single dimension: the printf shows a single thread, and only part of the batch gets updated.
Minimal reproducer:

```python
import warp as wp
import jax
import jax.numpy as jnp
from warp.jax_experimental.ffi import jax_kernel

wp.init()

# 1. The Warp Kernel

@wp.kernel
def scalar_add_kernel(
    data_io: wp.array(dtype=wp.float32, ndim=1),
    val: float
):
    wp.printf("Thread %d processing data_io with shape %d\n", wp.tid(), data_io.shape[0])
    for i in range(data_io.shape[0]):
        data_io[i] = data_io[i] + val

# 2. The JAX Interface

def get_jax_kernel_fn():
    return jax_kernel(
        scalar_add_kernel,
        launch_dims=(1,),
        in_out_argnames=['data_io'],
        vmap_method='broadcast_all'
    )

# 3. Usage with vmap

def batch_process(batch_array, value_to_add, axis_to_vmap):
    kernel_fn = get_jax_kernel_fn()

    vmapped_fn = jax.vmap(kernel_fn, in_axes=(axis_to_vmap, None))

    return vmapped_fn(batch_array, value_to_add)

# 4. Run Test
data = jnp.zeros((3, 5), dtype=jnp.float32)

print("Input (3x5):\n", data)

result = jax.jit(batch_process, static_argnums=(1, 2))(data, 10.0, 0)

print("\nOutput (should be all 10s):\n", result)

result_ = jax.jit(batch_process, static_argnums=(1, 2))(data, 10.0, 1)

print("\nOutput with different striding (should be all 10s):\n", result_)
```
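For comparison, here is a plain NumPy sketch of what the vmapped call is meant to compute (no Warp involved; the helper names `scalar_add_row` and `vmap_reference` are hypothetical). It also shows that, with `jax.vmap`'s default `out_axes=0`, mapping over axis 1 is expected to return a (5, 3) array, so the transposed shape of the second result is not itself the problem; only the values are.

```python
import numpy as np


def scalar_add_row(row, val):
    # Intended per-row semantics of scalar_add_kernel: add `val` to every element.
    return row + val


def vmap_reference(batch, val, axis):
    # Mimic jax.vmap(f, in_axes=(axis, None)) with the default out_axes=0:
    # map over `axis` of the input and stack the per-row results along a new leading axis.
    rows = np.moveaxis(batch, axis, 0)
    return np.stack([scalar_add_row(row, val) for row in rows])


data = np.zeros((3, 5), dtype=np.float32)
print(vmap_reference(data, 10.0, 0).shape)  # (3, 5)
print(vmap_reference(data, 10.0, 1).shape)  # (5, 3)
```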

Output:

```
Input (3x5):
 [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
Module __main__ 057841c load on device 'cuda:0' took 433.62 ms (compiled)
Thread 0 processing data_io with shape 3

Output (should be all 10s):
 [Array([[10., 10., 10.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)]

Output with different striding (should be all 10s):
Thread 0 processing data_io with shape 5
 [Array([[10., 10., 10.],
       [10., 10.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]], dtype=float32)]
```
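One reading of this output reproduces both results exactly, sketched in NumPy below. It rests on an assumption, not on anything confirmed about Warp's FFI internals: under `vmap_method='broadcast_all'`, the batch axis is moved to the front and the kernel receives the whole batched, row-major buffer, while `data_io` (declared `ndim=1`) reports only `shape[0]` and is indexed as flat memory. Combined with `launch_dims=(1,)`, a single thread then adds `val` to just the first `shape[0]` elements (3 for axis 0, 5 for axis 1). The helper name `simulate_observed` is hypothetical.

```python
import numpy as np


def simulate_observed(batch, val, axis):
    # Hypothesis (not confirmed against Warp internals): the batch axis is moved to
    # the front and the kernel sees the whole batched, row-major buffer, but because
    # data_io is declared ndim=1 it only reports shape[0] and indexes flat memory.
    moved = np.ascontiguousarray(np.moveaxis(batch, axis, 0))
    flat = moved.reshape(-1).copy()
    n = moved.shape[0]  # what data_io.shape[0] would print: 3 for axis=0, 5 for axis=1
    # launch_dims=(1,): a single thread (tid 0) loops over range(shape[0]).
    for i in range(n):
        flat[i] += val
    return flat.reshape(moved.shape)


data = np.zeros((3, 5), dtype=np.float32)
print(simulate_observed(data, 10.0, 0))
print(simulate_observed(data, 10.0, 1))
```

If this reading holds, the problem lies in the kernel rather than in the `jax.vmap` call: a kernel written for one 1-D row and launched with a fixed `launch_dims=(1,)` never sees the batch dimension. Directions worth checking against the Warp documentation are assigning one element per thread via `wp.tid()` instead of looping inside a single thread, and declaring arrays with the rank they actually have once batched.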

Could you please point out the error in my code?

Labels: question