Description
Motivation
While working on Devstral 2 and, more recently, GLM4V, I noticed that both the RoPE kernels and the ggml_rope_ext API are quite difficult to work with. I also spotted some opportunities to improve performance a bit, so I'm writing this proposal.
Problems with the existing kernel:
- It repeatedly does some redundant calculations, for example theta_scale, or the mscale calculation in the YaRN code path. These can be calculated once on the host and the results passed into the kernel.
- The M-RoPE code paths are difficult to follow (this got worse after interleaved M-RoPE was added alongside Qwen3-VL).
Problems with the API:
- ggml_rope_ext has too many arguments
- ggml_rope_multi has even more arguments, and half of them are hardly used in practice
- mode was supposed to be a bit field, but got messy with M-RoPE
- attn_factor is misaligned with transformers, which made the recent Devstral 2 debugging quite tricky
Proposal
My proposal has two parts: the API and the kernel.
For the API, I propose a set of calls like this:
```c
// rope Qcur with pos
struct ggml_tensor * roped = ggml_rope_v2(ctx, Qcur, pos, n_dims, freq_base);
// with yarn
ggml_rope_v2_set_yarn(ctx, roped, ext_factor, mscale, beta_fast, beta_slow);
// with m-rope
ggml_rope_v2_set_mrope(ctx, roped, sections);
// with ordering (NEOX or NORMAL)
ggml_rope_v2_set_ordering(ctx, roped, GGML_ROPE_ORDERING_NEOX);
```

Behind the scenes, these calls will precompute as much as possible before passing arguments to the kernel. For example, theta_scale = powf(freq_base, -2.0f/n_dims) will be calculated inside ggml_rope_v2's implementation.
For the kernel, we can statically compile templated kernels for every combination of:
- 2 x direction: forward or backward
- 2 x types: f16, f32
- 2 x modes: normal, M-RoPE

Ordering (normal or NEOX) is not a template parameter; the indexing will be controlled via an input argument instead.
Note: the vision mode (aka 2D RoPE with NEOX ordering) will instead become a composition of two ggml_rope_v2 ops.
So in total, we will have 2 x 2 x 2 = 8 statically compiled kernels.
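The dispatch side could look like the following C sketch, where a macro stamps out one function per (direction, type, mode) combination as a stand-in for template instantiation, and a 3-bit index selects among the eight. Ordering is deliberately absent from the index since it is a runtime argument. The kernel bodies are placeholders and all names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

typedef enum { ROPE_DIR_FORWARD = 0, ROPE_DIR_BACKWARD = 1 } rope_dir;
typedef enum { ROPE_TYPE_F32 = 0, ROPE_TYPE_F16 = 1 } rope_type;
typedef enum { ROPE_MODE_NORMAL = 0, ROPE_MODE_MROPE = 1 } rope_mode;

/* Placeholder signature: a real kernel would take the tensors and the
 * precomputed parameter block; here it only reports which variant it is. */
typedef const char * (*rope_kernel_fn)(void);

/* One function per combination, the C analogue of instantiating a template
 * once per (direction, type, mode). */
#define DEFINE_ROPE_KERNEL(DIR, TYPE, MODE)                          \
    static const char * rope_kernel_##DIR##_##TYPE##_##MODE(void) { \
        return #DIR "_" #TYPE "_" #MODE;                             \
    }

DEFINE_ROPE_KERNEL(fwd, f32, normal)
DEFINE_ROPE_KERNEL(fwd, f32, mrope)
DEFINE_ROPE_KERNEL(fwd, f16, normal)
DEFINE_ROPE_KERNEL(fwd, f16, mrope)
DEFINE_ROPE_KERNEL(bwd, f32, normal)
DEFINE_ROPE_KERNEL(bwd, f32, mrope)
DEFINE_ROPE_KERNEL(bwd, f16, normal)
DEFINE_ROPE_KERNEL(bwd, f16, mrope)

/* Table ordered by index = dir * 4 + type * 2 + mode. */
static const rope_kernel_fn rope_kernels[8] = {
    rope_kernel_fwd_f32_normal, rope_kernel_fwd_f32_mrope,
    rope_kernel_fwd_f16_normal, rope_kernel_fwd_f16_mrope,
    rope_kernel_bwd_f32_normal, rope_kernel_bwd_f32_mrope,
    rope_kernel_bwd_f16_normal, rope_kernel_bwd_f16_mrope,
};

static rope_kernel_fn rope_select_kernel(rope_dir dir, rope_type type, rope_mode mode) {
    const size_t idx = (size_t) dir * 4 + (size_t) type * 2 + (size_t) mode;
    assert(idx < 8);
    return rope_kernels[idx];
}
```

Keeping ordering as a runtime argument is what holds the table at eight entries instead of sixteen.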