We now have support for expert parallelism with #842. Currently this is implemented by running `ragged_dot` on the subset of local experts. To implement this efficiently with JAX JIT (i.e. with static shapes), we need the `group_offset` parameter of `ragged_dot` so that we can target only the tokens that were routed to the local experts. However, `group_offset` is currently not implemented in `jax.lax.ragged_dot` (see also the discussion in jax-ml/jax#34168, which is worth reading). For the time being we therefore implemented a workaround in #860 that makes the code runnable and is already surprisingly efficient; it could still be optimized by not running expert computations on the extra tokens.
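To pin down what we want, here is a minimal pure-JAX sketch of the intended semantics (the function name `ragged_dot_with_group_offset`, and the choice to zero out rows that do not belong to a local expert, are my assumptions, not a final API). It assumes the rows of `lhs` are sorted by group and is meant only as a correctness reference, not as a fast implementation:

```python
# Minimal pure-JAX sketch of ragged_dot with a group_offset, for semantics only.
# Assumptions (not a final API): lhs rows are sorted by group, and rows whose
# group falls outside the local expert slice produce zeros in the output.
import jax
import jax.numpy as jnp


def ragged_dot_with_group_offset(lhs, rhs_local, group_sizes, group_offset):
    """lhs: [m, k], rhs_local: [g_local, k, n], group_sizes: [g_total] -> [m, n]."""
    m = lhs.shape[0]
    g_local = rhs_local.shape[0]
    # Recover each row's global group id from the (global) group sizes.
    row_ends = jnp.cumsum(group_sizes)
    gid = jnp.searchsorted(row_ends, jnp.arange(m), side="right")
    # Map into the local expert slice; rows outside the local range get masked out.
    local_gid = gid - group_offset
    in_range = (local_gid >= 0) & (local_gid < g_local)
    safe_gid = jnp.clip(local_gid, 0, g_local - 1)
    # Reference (inefficient) contraction: gather one [k, n] expert matrix per row.
    out = jnp.einsum("mk,mkn->mn", lhs, rhs_local[safe_gid])
    return jnp.where(in_range[:, None], out, 0.0)


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    lhs = jax.random.normal(key, (8, 4))           # 8 tokens, k=4
    rhs_local = jax.random.normal(key, (2, 4, 3))  # 2 local experts, n=3
    group_sizes = jnp.array([2, 3, 2, 1])          # 4 experts in total
    print(ragged_dot_with_group_offset(lhs, rhs_local, group_sizes, group_offset=1).shape)
```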
There are several ways to do this:
- Implement `group_offset` for `jax.lax.ragged_dot`. We are most interested in the GPU case for now, but later it will also be good to do it for TPUs (and until then we can use the fallback code there). There are some pointers on how to do this in jax-ml/jax#34168 (comment) ("Expert parallelism and 'Unimplemented group_offset support' for `jax.lax.ragged_dot`") -- it is conceptually pretty simple, but it will need some understanding of, and wrangling with, XLA. It has the advantage that we would likely not need extra auto-tuning on top of what XLA already does. It would need some discussion / syncing with the JAX/XLA teams, but I think a prototype could be written with the code that is already open source.
- (Likely the simplest option) Implement a `ragged_dot` that supports `group_offset` via Pallas. There are several implementations we could adapt, such as:
  - https://github.com/rdyro/gpu_ragged_dot/blob/main/gpu_ragged_dot.py (a very simple implementation that also shows how to support the forward and backward pass; see the backward-pass sketch after this list)
  - the megablox kernel in https://github.com/AI-Hypercomputer/maxtext, which I believe is also mirrored in https://github.com/sgl-project/sglang-jax/blob/main/python/sgl_jax/srt/kernels/gmm/megablox_gmm_kernel/gmm.py
  - the "official" Pallas kernels (https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py, https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/blackwell_ragged_dot_mgpu.py)
  - the kernels in https://github.com/openxla/tokamax
- Integrate an existing kernel from the CUDA ecosystem, like https://github.com/NVIDIA/cutile-python/blob/main/samples/MoE.py (this currently only supports Blackwell, but cuda-tile will likely support older architectures going forward; see "Ampere and mma.sp support", NVIDIA/cuda-tile#8). We could probably also integrate a Triton kernel via https://github.com/jax-ml/jax-triton, or call into any other CUDA library.
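Whichever kernel we end up with, it will need a custom VJP, following the forward/backward split in the gpu_ragged_dot example above. Below is a rough sketch of the backward pass under the same assumed semantics as the reference sketch earlier (rows outside the local expert range contribute nothing); the function name is a placeholder, and it is written densely in plain JAX as a correctness reference rather than as an efficient kernel:

```python
# Hedged sketch of the backward pass for a ragged_dot-with-group_offset kernel
# (name and zero-fill semantics are assumptions, matching the reference sketch
# above). d_lhs is itself a ragged dot against the per-expert matrices, and
# d_rhs is a per-expert accumulation of outer products.
import jax.numpy as jnp


def ragged_dot_with_group_offset_bwd(lhs, rhs_local, group_sizes, group_offset, dout):
    """lhs: [m, k], rhs_local: [g_local, k, n], dout: [m, n] -> (d_lhs, d_rhs)."""
    m, g_local = lhs.shape[0], rhs_local.shape[0]
    # Recover each row's global group id, then map it into the local expert slice.
    row_ends = jnp.cumsum(group_sizes)
    gid = jnp.searchsorted(row_ends, jnp.arange(m), side="right")
    local_gid = gid - group_offset
    in_range = (local_gid >= 0) & (local_gid < g_local)
    safe_gid = jnp.clip(local_gid, 0, g_local - 1)
    # d_lhs[m, k] = sum_n dout[m, n] * rhs_local[gid(m), k, n] for in-range rows.
    d_lhs = jnp.einsum("mn,mkn->mk", dout, rhs_local[safe_gid])
    d_lhs = jnp.where(in_range[:, None], d_lhs, 0.0)
    # d_rhs[g, k, n] = sum over rows m in local group g of lhs[m, k] * dout[m, n],
    # expressed here with a one-hot scatter over the local experts.
    one_hot = (local_gid[:, None] == jnp.arange(g_local)[None, :]) & in_range[:, None]
    d_rhs = jnp.einsum("mg,mk,mn->gkn", one_hot.astype(lhs.dtype), lhs, dout)
    return d_lhs, d_rhs
```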
Ideally the implementation would be simple yet performant and would support as many platforms (e.g. Hopper, Blackwell, older GPU architectures, TPUs) as possible, but it might be hard to satisfy all of these requirements, so partial progress along any of these dimensions (while using the current fallback elsewhere) is very welcome.
Improving the performance here (and possibly also the performance of vanilla `jax.lax.ragged_dot` without `group_offset`) will have a huge impact, because that kernel is used so heavily, both for MultiLoRA support and for expert handling.