Context Parallel for Linear Attention (also known as KCP inside Moonshot) is context parallelism designed for delta-rule recurrent models such as GDN (Gated DeltaNet) and KDA (Kimi Delta Attention). It enables efficient distributed training by partitioning the sequence dimension across ranks: each rank processes a local token chunk, and CP automatically synchronizes the recurrent states across ranks.
```python
import torch
import torch.distributed as dist

from fla.ops.cp import build_cp_context

# Global cu_seqlens before partitioning (device can be CPU or GPU)
cu_seqlens_global = torch.tensor(
    [0, s1, s1 + s2, ..., total],
    dtype=torch.long,
    device=device,
)

# conv1d_kernel_size is required for the causal_conv1d CP path
cp_context = build_cp_context(
    cu_seqlens_global,
    group=dist.group.WORLD,
    conv1d_kernel_size=W,
)
```

```python
from fla.modules.convolution import causal_conv1d

# x_local is the rank-local chunk: [1, T_local, D]
y_local, _ = causal_conv1d(
    x=x_local,
    weight=weight_local,
    bias=bias_local,
    activation="swish",
    cp_context=cp_context,
)
```

> [!NOTE]
> - `cp_context` is required; `cp_context.conv1d_kernel_size` and `cp_context.cu_seqlens` must be set.
> - Do not pass `cu_seqlens`/`cu_seqlens_cpu` manually — they are taken from the context.
```python
import torch.nn.functional as F

from fla.ops.kda import chunk_kda

o_local, _ = chunk_kda(
    q=F.normalize(q_local, p=2, dim=-1),
    k=F.normalize(k_local, p=2, dim=-1),
    v=v_local,
    g=g_local,
    beta=beta_local,
    cp_context=cp_context,
    disable_recompute=disable_recompute,
)
```

> [!NOTE]
> - CP expects `B == 1` for varlen and uses the rank-local `cu_seqlens` from the context.
> - `initial_state` and `output_final_state=True` are not supported in CP mode.
The CP context stores rank-local varlen metadata that tracks how sequences are distributed:

- `FLACPContext.cu_seqlens` — rank-local cumulative sequence lengths, on GPU (int64)
- `FLACPContext.cu_seqlens_cpu` — the same data on CPU for host-side indexing

Variable-length inputs start as global cu_seqlens before partitioning; `build_cp_context` converts them into rank-local metadata automatically.
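To make the global-to-local conversion concrete, here is a minimal pure-Python sketch. The function name `local_cu_seqlens` and the contiguous even token sharding are our illustrative assumptions, not the FLA API — `build_cp_context` performs this step internally:

```python
# Hypothetical sketch (not the FLA implementation): convert global cu_seqlens
# into rank-local cu_seqlens when the token dimension is sharded contiguously
# and evenly across CP ranks.
def local_cu_seqlens(cu_seqlens_global, rank, world_size):
    total = cu_seqlens_global[-1]
    assert total % world_size == 0, "tokens must split evenly across ranks"
    shard = total // world_size
    lo, hi = rank * shard, (rank + 1) * shard
    # Clip every global boundary into [lo, hi] and shift to rank-local offsets.
    clipped = [min(max(b, lo), hi) - lo for b in cu_seqlens_global]
    # Drop duplicate boundaries from sequences lying entirely outside the shard.
    out = [clipped[0]]
    for b in clipped[1:]:
        if b != out[-1]:
            out.append(b)
    return out

# Two sequences of lengths 6 and 10, split across 2 ranks of 8 tokens each:
print(local_cu_seqlens([0, 6, 16], rank=0, world_size=2))  # -> [0, 6, 8]
print(local_cu_seqlens([0, 6, 16], rank=1, world_size=2))  # -> [0, 8]
```

Note how the second sequence straddles the shard boundary: on rank 0 it contributes a 2-token prefix, and on rank 1 it continues as the (single) sequence carried over from the previous rank.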
We follow the notation from the Kimi Linear technical report (Section 2.1). Throughout this document, a plain subscript $t$ indexes token positions, while a bracketed subscript $[t]$ indexes chunks.
Vectors and matrices:
- $\boldsymbol{q}_t, \boldsymbol{k}_t, \boldsymbol{v}_t, \boldsymbol{o}_t, \boldsymbol{u}_t, \boldsymbol{w}_t$ — column vectors in $\mathbb{R}^{d_k}$ or $\mathbb{R}^{d_v}$ at position $t$
- $\mathbf{S}_t \in \mathbb{R}^{d_k \times d_v}$ — matrix-form memory state; FLA kernels store it as `[d_k, d_v]`, some backends transpose to `[d_v, d_k]`
- $\mathbf{X}_{[t]}$ — stacked vectors within chunk $t$ (shape $C \times d$); a sequence of length $L$ splits into $L/C$ chunks of size $C$
- $\boldsymbol{x}^r_{[t]}$ — the $r$-th element in chunk $t$, i.e., $\boldsymbol{x}_{tC+r}$, where $t \in [0, L/C)$ and $r \in [1, C]$
State and decay:

The decay factor $\gamma^r = \prod_{i \le r} \alpha_i$ is the cumulative decay up to position $r$ within a chunk, and $\gamma^{r \to C} = \gamma^C / \gamma^r$ is the remaining decay from position $r$ to the chunk end.

Code mapping:

- `g` stores $\log(\alpha)$ (or $\log_2(\alpha)$ for KDA)
- After `chunk_local_cumsum`, `g` at position $r$ equals $\log \gamma^r$
- Then $\exp(g_r) = \gamma^r$ and $\exp(g_{\mathrm{last}} - g_r) = \gamma^{r \to C}$
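A quick scalar sanity check of this bookkeeping, in pure Python with toy gate values (the cumulative-sum loop stands in for `chunk_local_cumsum`):

```python
import math

# Per-token decay gates alpha in (0, 1]; g stores log(alpha) before the cumsum.
alphas = [0.9, 0.8, 0.95, 0.7]
g = [math.log(a) for a in alphas]
for r in range(1, len(g)):      # in-chunk cumulative sum, like chunk_local_cumsum
    g[r] += g[r - 1]

gamma_r = [math.exp(x) for x in g]                # gamma^r: decay up to position r
gamma_r_to_C = [math.exp(g[-1] - x) for x in g]   # gamma^{r->C}: decay from r to chunk end

# exp(g_r) reproduces the running product of the gates ...
prod = 1.0
for a, gr in zip(alphas, gamma_r):
    prod *= a
    assert abs(gr - prod) < 1e-12
# ... and exp(g_last - g_r) is the product of the gates *after* position r.
assert abs(gamma_r_to_C[0] - 0.8 * 0.95 * 0.7) < 1e-12
```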
Both GDN and KDA are built on the delta rule — a recurrent update in which the state matrix $\mathbf{S}_t$ is corrected by a rank-1 delta toward the current key-value pair.

GDN uses a single scalar gate per head per token. From [Yang et al., 2025]:

$$\mathbf{S}_t = \alpha_t \left(\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top\right) \mathbf{S}_{t-1} + \beta_t \boldsymbol{k}_t \boldsymbol{v}_t^\top$$

KDA extends GDN with a per-dimension gate, giving finer control over which features to retain or forget. From Eq. 1 in the Kimi-k1.5 report:

$$\mathbf{S}_t = \left(\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top\right) \mathrm{Diag}(\boldsymbol{\alpha}_t)\, \mathbf{S}_{t-1} + \beta_t \boldsymbol{k}_t \boldsymbol{v}_t^\top$$

For efficiency, we process tokens in chunks using the WY representation (Eq. 7 in the report), which computes auxiliary matrices $\mathbf{W}_{[t]}$ and $\mathbf{U}_{[t]}$ (the `w` and `u` tensors in code) so that a whole chunk can be applied with dense matrix multiplies.
This formulation is key to CP: it lets us compute how the state transforms across a chunk, enabling efficient cross-rank synchronization.
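For reference, the per-token GDN update above can be written as a naive loop over nested lists — a minimal sketch of the recurrence with the state stored as `[d_k, d_v]`, not the FLA kernel:

```python
# One gated-delta-rule step: S <- alpha * (I - beta k k^T) S + beta k v^T.
# S is a d_k x d_v nested list; k, v are plain lists; alpha, beta are scalars.
def gdn_step(S, k, v, alpha, beta):
    d_k, d_v = len(S), len(S[0])
    # k^T S: project the current state onto the key direction (length d_v).
    kS = [sum(k[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
    return [[alpha * (S[i][j] - beta * k[i] * kS[j]) + beta * k[i] * v[j]
             for j in range(d_v)] for i in range(d_k)]

# From a zero state, one step writes beta * k v^T into memory.
S0 = [[0.0, 0.0] for _ in range(2)]
S1 = gdn_step(S0, k=[1.0, 0.5], v=[2.0, -1.0], alpha=0.9, beta=0.5)
# -> [[1.0, -0.5], [0.5, -0.25]]
```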
While both models share the delta rule structure, they differ in how gating is applied — a distinction that affects the CP implementation.
GDN's scalar gate is cheap to apply inside kernels, so we pass the original tensors and let the kernel handle gating internally:
- Gate: $\alpha_t \in [0,1]$, one scalar per head per token
- Code: `g` has shape `[B, T, H]` where $\alpha = \exp(g)$; processed by `chunk_local_cumsum`
- Kernel input: original $\boldsymbol{k}$, $\boldsymbol{q}$, and scalar `g`
- Internal gating (`USE_G=True`):
  - Inter-chunk decay: $\mathbf{S} \leftarrow \gamma^C \cdot \mathbf{S}$ (scalar broadcast)
  - Gated key: $\tilde{\boldsymbol{k}}^r = \boldsymbol{k}^r \cdot \gamma^{r \to C}$
  - Gated query: $\tilde{\boldsymbol{q}}^r = \boldsymbol{q}^r \cdot \gamma^r$ (backward only)
KDA's per-dim gate is too expensive to apply on the fly inside the kernel, so gating is pre-applied to the tensors beforehand:

- Gate: $\boldsymbol{\alpha}_t \in [0,1]^{d_k}$, one value per dimension per token
- Code: `g` has shape `[B, T, H, K]` where $\alpha = \exp_2(g)$; processed by `kda_gate_chunk_cumsum`
- Pre-gated tensors (from `chunk_kda_fwd_intra` / `recompute_w_u_fwd`):
  - `kg`: row $r$ is $\boldsymbol{k}^r \odot \gamma^{r \to C}$, i.e., `k * exp2(gk_last - gk)`
  - `qg`: row $r$ is $\boldsymbol{q}^r \odot \gamma^{r}$, i.e., `q * exp2(gk)` (saved for backward)
- Kernel input: pre-gated `kg` (and `qg` in backward), plus `gk=g` for inter-chunk decay
- Kernel gating (`USE_GK=True`): only the chunk-level decay $\mathbf{S} \leftarrow \mathrm{Diag}(\gamma^C)\, \mathbf{S}$
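A toy numeric sketch of the pre-gating identity `kg = k * exp2(gk_last - gk)`, with made-up cumulative gate values (pure Python, 3 tokens, $d_k = 2$):

```python
# g holds cumulative per-dimension log2 decays (i.e. after the gate cumsum);
# it is monotonically decreasing along the token axis.
g = [[-0.1, -0.05], [-0.3, -0.2], [-0.5, -0.4]]
k = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.25]]
g_last = g[-1]

# Row r of kg is k^r elementwise-scaled by gamma^{r->C} = 2**(g_last - g_r).
kg = [[k[r][i] * 2.0 ** (g_last[i] - g[r][i]) for i in range(2)]
      for r in range(3)]

# The last position has no remaining decay, so its row passes through unchanged.
assert kg[-1] == k[-1]
```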
This design means CP pre-processing must use the same tensors as the main kernel — original for GDN, pre-gated for KDA.
The core challenge of CP is that each rank only sees a local chunk, but the recurrent state depends on all previous tokens. We solve this with an all-gather + merge pattern:
- Local computation: each rank computes $(\mathbf{S}_\text{ext}, \mathbf{M})$ from its chunk
  - $\mathbf{S}_\text{ext} \in \mathbb{R}^{d_k \times d_v}$: accumulated state assuming $\mathbf{S}_0 = \mathbf{0}$
  - $\mathbf{M} \in \mathbb{R}^{d_k \times d_k}$: transition matrix capturing how the chunk transforms incoming state
- All-gather: collect $[\mathbf{S}_\text{ext}, \mathbf{M}]$ from all ranks
- Merge: rank $r$ reconstructs its initial state by chaining contributions from ranks $< r$:

$$\mathbf{S}_0^{(r)} = \sum_{i=0}^{r-1} \mathbf{M}^{(r-1)} \cdots \mathbf{M}^{(i+1)}\, \mathbf{S}_\text{ext}^{(i)}$$

This step computes exactly the state rank $r$ would have received in a fully serial pass.
Stage 1 — Accumulated state $\mathbf{S}_\text{ext}$

We simulate processing the chunk with zero initial state: initialize $\mathbf{S} = \mathbf{0}$, run the chunkwise recurrence over the local tokens, and take the final state as $\mathbf{S}_\text{ext}$.

Stage 2 — Transition matrix $\mathbf{M}$

The transition matrix captures how incoming state is transformed: initialize with the identity and apply only the state-transforming part of each update (decay and delta correction, without injecting values); the result is the chunk's $\mathbf{M}$.

Merge (forward direction):

For rank $r$, chain the contributions of the `pre_num_ranks` previous ranks, starting from $\mathbf{S} = \mathbf{0}$ and iterating $i = 0, \dots, r-1$:

$$\mathbf{S} \leftarrow \mathbf{M}^{(i)} \mathbf{S} + \mathbf{S}_\text{ext}^{(i)}$$

The final $\mathbf{S}$ is rank $r$'s initial state $\mathbf{S}_0^{(r)}$.
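The merge can be sanity-checked against a serial pass with a toy affine recurrence — plain Python, with 2×2 matrices standing in for the real `[d_k, d_v]` states and `(M, S_ext)` pairs standing in for the per-rank local results:

```python
# Each "rank" contributes an affine map S_out = M @ S_in + S_ext.
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def madd(A, B):
    return [[A[i][j] + B[i][j] for j in range(len(A[0]))] for i in range(len(A))]

chunks = [  # (M, S_ext) per rank, as produced by the local stage
    ([[0.9, 0.1], [0.0, 0.8]], [[1.0, 0.0], [0.5, 2.0]]),
    ([[0.7, 0.0], [0.2, 0.6]], [[0.0, 1.0], [1.0, 0.0]]),
    ([[0.5, 0.3], [0.1, 0.9]], [[2.0, 2.0], [0.0, 1.0]]),
]

# Serial reference: push the state through every chunk in order,
# recording the state each rank would receive as its initial state.
S = [[0.0, 0.0], [0.0, 0.0]]
serial = []
for M, S_ext in chunks:
    serial.append(S)
    S = madd(matmul(M, S), S_ext)

# Merge: rank r chains the (M, S_ext) of ranks < r — the forward merge loop.
for r in range(len(chunks)):
    S0 = [[0.0, 0.0], [0.0, 0.0]]
    for M, S_ext in chunks[:r]:
        S0 = madd(matmul(M, S0), S_ext)
    assert S0 == serial[r]
```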
The backward pass has the same structure but reversed direction — we merge from ranks after the current rank to propagate gradients backward through the sequence.
Stage 1 — Accumulated gradient $\mathrm{d}\mathbf{S}_\text{ext}$

Initialize $\mathrm{d}\mathbf{S} = \mathbf{0}$ and run the backward recurrence over the local chunk, where gradients flow from later to earlier tokens through the transposed per-token transforms.

Stage 2 — Gradient transition matrix

Initialize with the identity; the backward pass chains $\mathbf{M}^\top$, the transpose of the forward transition matrix.

Merge (backward direction):

For rank $r$, chain the contributions of the `post_num_ranks` following ranks, iterating from the nearest following rank outward:

$$\mathrm{d}\mathbf{S} \leftarrow \mathbf{M}^{(i)\top}\, \mathrm{d}\mathbf{S} + \mathrm{d}\mathbf{S}_\text{ext}^{(i)}$$

The final $\mathrm{d}\mathbf{S}$ is the incoming gradient `dht` for rank $r$'s last state.
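A toy check of why the backward merge uses transposes: for the affine map $\mathbf{S}_\text{out} = \mathbf{M}\mathbf{S}_\text{in} + \mathbf{S}_\text{ext}$, the vector-Jacobian product with respect to $\mathbf{S}_\text{in}$ is $\mathbf{M}^\top \mathrm{d}\mathbf{S}_\text{out}$. Plain Python; the map is linear, so a finite-difference check is exact up to float error:

```python
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

M = [[0.9, 0.1], [0.2, 0.8]]
dOut = [[1.0, 2.0], [3.0, 4.0]]

# Analytic gradient of <dOut, M @ S_in> w.r.t. S_in: M^T @ dOut.
MT = [[M[j][i] for j in range(2)] for i in range(2)]
grad = matmul(MT, dOut)

# Finite-difference check on entry (0, 1) of S_in.
eps = 1e-6
S_in = [[0.0, 0.0], [0.0, 0.0]]
S_in[0][1] += eps
out = matmul(M, S_in)
f = sum(dOut[i][j] * out[i][j] for i in range(2) for j in range(2))
assert abs(f / eps - grad[0][1]) < 1e-6
```

This is exactly the rule the backward merge applies when chaining $\mathbf{M}^{(i)\top}$ across following ranks.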
The following examples show how CP integrates with the existing kernel interfaces.
GDN forward:

```python
g = chunk_local_cumsum(g, chunk_size=64)
w, u = recompute_w_u_fwd(k, v, beta, A, g=g)

# CP pre-process: original k, scalar g
initial_state = chunk_gated_delta_rule_fwd_h_pre_process(
    k=k, w=w, u=u, g=g,  # USE_G=True, USE_GK=False
    context=cp_context,
)

# Main kernel: original k, scalar g
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
    k=k, w=w, u=u, g=g,
    initial_state=initial_state,
)
```

GDN backward:

```python
w, u = recompute_w_u_fwd(k, v, beta, A, g=g)
h, v_new, _ = chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, g=g, ...)
dv = chunk_bwd_dv_local(q=q, k=k, g=g, do=do, ...)

# CP pre-process: original q, k, scalar g
dht, initial_state = chunk_gated_delta_rule_bwd_dhu_pre_process(
    q=q, k=k, w=w, do=do, dv=dv, g=g,  # USE_G=True, USE_GK=False
    context=cp_context,
)

# Main kernel: original q, k, scalar g
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
    q=q, k=k, w=w, g=g,
    dht=dht, ...
)
```

KDA forward:

```python
# 1. Intra-chunk: compute WY repr + pre-gated tensors
w, u, qg, kg, Aqk, Akk = chunk_kda_fwd_intra(q, k, v, gk=g, beta=beta, ...)
# kg = K ⊙ exp2(γ^{r→C}), i.e., rows of Γ^{i→C} ⊙ K
# qg = Q ⊙ exp2(γ^r), i.e., rows of Γ^{1→C} ⊙ Q (saved for backward)

# 2. CP pre-process: pre-gated kg, per-dim gk=g
initial_state = chunk_gated_delta_rule_fwd_h_pre_process(
    k=kg, w=w, u=u, gk=g,  # USE_G=False, USE_GK=True, use_exp2=True
    context=cp_context,
)

# 3. Main kernel: pre-gated kg, per-dim gk=g
h, v_new, _ = chunk_gated_delta_rule_fwd_h(
    k=kg, w=w, u=u, gk=g,
    initial_state=initial_state,
    use_exp2=True,
)
```

KDA backward:

```python
# 1. Recompute WY repr
w, u, qg, kg = recompute_w_u_fwd(q, k, v, beta, A=Akk, gk=g, ...)
# qg = Q ⊙ exp2(γ^r), kg = K ⊙ exp2(γ^{r→C})

# 2. Recompute state
h, v_new, _ = chunk_gated_delta_rule_fwd_h(k=kg, w=w, u=u, gk=g, ...)

# 3. Compute local dv
dAqk, dv = chunk_kda_bwd_dAv(q, k, v=v_new, do=do, A=Aqk, ...)

# 4. CP pre-process: pre-gated qg, kg, per-dim gk=g
dht, initial_state = chunk_gated_delta_rule_bwd_dhu_pre_process(
    q=qg, k=kg, w=w, do=do, dv=dv, gk=g,  # USE_G=False, USE_GK=True, use_exp2=True
    context=cp_context,
)

# 5. Main kernel: pre-gated qg, kg
dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
    q=qg, k=kg, w=w, gk=g,
    dht=dht, ...,
    use_exp2=True,
)
```

| Function | GDN | KDA | Gate Path |
|---|---|---|---|
| `pre_process_fwd` | `k=k, g=g` | `k=kg, gk=g` | GDN: `USE_G`, KDA: `USE_GK` |
| `fwd_h` | `k=k, g=g` | `k=kg, gk=g` | Same as pre_process |
| `pre_process_bwd` | `q=q, k=k, g=g` | `q=qg, k=kg, gk=g` | GDN: `USE_G`, KDA: `USE_GK` |
| `bwd_dhu` | `q=q, k=k, g=g` | `q=qg, k=kg, gk=g` | Same as pre_process |
Key consistency: Pre-process and main kernel must always receive the same tensors — this is critical for correctness.
- KDA: both receive pre-gated `kg` ($\boldsymbol{\Gamma}^{i \to C} \odot \mathbf{K}$) and `qg` ($\boldsymbol{\Gamma}^{1 \to C} \odot \mathbf{Q}$)
- GDN: both receive original $\boldsymbol{k}$, $\boldsymbol{q}$ (gating applied inside the kernel)
The transition matrix $\mathbf{M}$ combines the chunk-level decay with the WY delta terms; the backward pass uses its transpose.

Forward:

$$\mathbf{M} = \mathrm{Diag}(\gamma^C) - \tilde{\mathbf{K}}^\top \mathbf{W}$$

Backward (transposed):

$$\mathbf{M}^\top = \mathrm{Diag}(\gamma^C) - \mathbf{W}^\top \tilde{\mathbf{K}}$$

The diagonal term depends on the gate type:

- GDN: $\exp(g_{\mathrm{last}}) \cdot \mathbf{I}$ — scalar times identity
- KDA: $\mathrm{Diag}(\gamma^C)$ — per-dim diagonal, where $\gamma^C = \exp_2(g_{\mathrm{last}})$ (i.e., `gk_last` in code)
Cross-rank state is computed by chaining these transition matrices during the merge step.

> [!IMPORTANT]
> In CP mode, only the first sequence in the local batch can be a continuation from a previous rank — all other sequences start fresh. This means only one initial state per rank can be non-zero.

- `compress_h0`: extracts just that one state to save memory during `save_for_backward`
- `expand_h0`: restores the full `[N, H, d_k, d_v]` tensor in backward
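A schematic sketch of this pair (shapes simplified to one state vector per sequence; the real helpers operate on `[N, H, d_k, d_v]` tensors):

```python
# Hypothetical sketch: only sequence 0 of the local batch can carry a
# cross-rank state, so only that slice needs to be saved for backward.
def compress_h0(h0):
    # Keep only the (possibly non-zero) first sequence's state.
    return h0[0]

def expand_h0(state0, n, zero_state):
    # Every other sequence restarts from a zero state.
    return [state0] + [zero_state] * (n - 1)

h0 = [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]
assert expand_h0(compress_h0(h0), 3, [0.0, 0.0]) == h0
```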
While this document focuses on delta-rule models such as GDN and KDA, the underlying CP mechanism is not restricted to delta-rule recurrences. Any linear attention formulation that can be expressed in chunkwise form — i.e., one where the state transition across a chunk decomposes into a transition matrix $\mathbf{M}$ and an accumulated state $\mathbf{S}_\text{ext}$ — can reuse the same machinery.
The only model-specific components are:

- how $\mathbf{M}$ and $\mathbf{S}_\text{ext}$ are computed from the local chunk;
- how the merge kernel chains these quantities across ranks.
As long as these two operations are well-defined, the same CP infrastructure (`build_cp_context`, all-gather, and merge) applies without changing the high-level data flow.
At the time of writing, CP has been implemented and verified for GDN, KDA, and DPLR (a.k.a. RWKV-7). If you would like to see support for another linear-attention variant, please feel free to open an issue.
Context Parallel of Linear Attention was first introduced in PR #691, implemented by Duyue MA. It is also known as KCP (Kimi Context Parallel) internally at Moonshot AI. The implementation in this repository was independently contributed to FLA and is a separate codebase from the internal Moonshot implementation.