
[tx] Implement efficient kernel for ragged_dot that supports expert parallelism #862

@pcmoritz

Description


We now have support for expert parallelism with #842. Currently this is implemented by running ragged_dot on the subset of local experts. To implement this efficiently under JAX JIT (i.e. with static shapes), we need the group_offset parameter of ragged_dot so we can target only the tokens that were routed to the local experts. However, group_offset is currently not implemented in jax.lax.ragged_dot (see also the discussion in jax-ml/jax#34168, which is worth reading). As a stopgap, #860 implements a workaround that makes the code runnable and is already surprisingly efficient, but it can be optimized further by not running expert computations on the extra tokens.
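
To make the missing piece concrete, here is a minimal dense sketch of the semantics we need from ragged_dot with group_offset, assuming the usual layout (lhs of shape [m, k] with tokens laid out contiguously by expert, rhs of shape [g_local, k, n] holding only the locally sharded experts, and group_sizes covering all global experts): tokens routed to experts outside the local window come out as zeros. The helper name ragged_dot_reference is only illustrative, and it gathers one weight matrix per row, so it is useful as a correctness baseline rather than as a kernel.

```python
import jax
import jax.numpy as jnp


def ragged_dot_reference(lhs, rhs, group_sizes, group_offset=0):
    """Dense reference of the intended group_offset semantics (illustrative).

    lhs:          [m, k]           tokens, contiguous by (global) expert.
    rhs:          [g_local, k, n]  weights of the locally held experts.
    group_sizes:  [g_total]        token count per global expert.
    group_offset: index of the first global expert held on this shard.

    Assumes m == sum(group_sizes) (no padding rows). Rows routed to experts
    outside [group_offset, group_offset + g_local) produce zeros.
    """
    m, _ = lhs.shape
    g_local = rhs.shape[0]
    # Start row of every global group in lhs.
    starts = jnp.cumsum(group_sizes) - group_sizes
    # Global group id of every row, remapped to the local expert index.
    row_group = jnp.sum(jnp.arange(m)[:, None] >= starts[None, :], axis=1) - 1
    local_group = row_group - group_offset
    in_range = (local_group >= 0) & (local_group < g_local)
    safe_local = jnp.where(in_range, local_group, 0)
    # Gather one expert matrix per row: O(m*k*n) memory, for testing only.
    per_row_rhs = rhs[safe_local]                      # [m, k, n]
    out = jnp.einsum("mk,mkn->mn", lhs, per_row_rhs)
    return jnp.where(in_range[:, None], out, jnp.zeros_like(out))
```

With group_offset=0 and rhs covering all experts this should agree with plain jax.lax.ragged_dot, which is a convenient sanity check both for this sketch and for any kernel written under the options below.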

There are several ways to do this:

  1. Implement group_offset for jax.lax.ragged_dot. We are most interested in the GPU case for now, but later it would also be good to do it for TPUs (and for now we can use the fallback code there; a static-shape fallback in the same spirit is sketched after this list). There are some pointers on how to do this in jax-ml/jax#34168 (comment) -- it is conceptually pretty simple but will need some understanding of and wrangling with XLA. It has the advantage that we would likely not need extra auto-tuning on top of what XLA already does. It would need some discussion / syncing with the JAX/XLA teams, but I think a prototype could be written with the code that is already open-source.
  2. (Likely the simplest option) Implement a ragged_dot that supports group_offset via Pallas. There are several existing implementations we could adapt.
  3. Integrate an existing kernel from the CUDA ecosystem, like https://github.com/NVIDIA/cutile-python/blob/main/samples/MoE.py (currently only supports Blackwell, but cuda-tile will likely support older architectures going forward, see "Ampere and mma.sp support" NVIDIA/cuda-tile#8). We could probably also integrate a Triton kernel (via https://github.com/jax-ml/jax-triton) or call into any other CUDA library.
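
As a point of reference while evaluating these options, below is a hedged sketch of a static-shape fallback built purely on plain jax.lax.ragged_dot (illustrative; not necessarily how #860 does it): the local expert weights are sandwiched between two all-zero "experts" that absorb every token routed outside the local window, so shapes stay static and non-local rows come out as zeros. The matmuls against the zero experts are exactly the wasted work a real group_offset kernel would avoid. The helper name ragged_dot_with_offset_fallback is hypothetical.

```python
import jax
import jax.numpy as jnp


def ragged_dot_with_offset_fallback(lhs, rhs_local, group_sizes, group_offset):
    """Emulate group_offset using plain jax.lax.ragged_dot with static shapes.

    Assumes lhs rows are contiguous by global expert, m == sum(group_sizes),
    and the window [group_offset, group_offset + g_local) lies inside the
    global expert range.
    """
    g_local, k, n = rhs_local.shape
    g_total = group_sizes.shape[0]
    group_ids = jnp.arange(g_total)
    # Token counts before and after the local expert window.
    before = jnp.sum(jnp.where(group_ids < group_offset, group_sizes, 0))
    after = jnp.sum(jnp.where(group_ids >= group_offset + g_local, group_sizes, 0))
    local_sizes = jax.lax.dynamic_slice_in_dim(group_sizes, group_offset, g_local)
    padded_sizes = jnp.concatenate([before[None], local_sizes, after[None]])
    # Two all-zero dummy experts swallow the non-local tokens.
    zero_expert = jnp.zeros((1, k, n), dtype=rhs_local.dtype)
    padded_rhs = jnp.concatenate([zero_expert, rhs_local, zero_expert], axis=0)
    return jax.lax.ragged_dot(lhs, padded_rhs, padded_sizes)
```

Any candidate kernel from the options above can be compared against this (or against ragged_dot_reference earlier) on random routing patterns before worrying about performance.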

Ideally the implementation would be simple yet performant and support as many platforms as possible (e.g. Hopper, Blackwell, older GPU architectures, TPUs), but it might be hard to satisfy all of these requirements, so partial progress along any of these dimensions (using the current fallback elsewhere) is very welcome.

Improving the performance here (and possibly also improving on vanilla jax.lax.ragged_dot without group_offset) will have a huge impact, because this kernel is used so heavily, both for MultiLoRA support and for expert handling.
