diff --git a/python/sgl_jax/srt/kernels/gdn/__init__.py b/python/sgl_jax/srt/kernels/gdn/__init__.py new file mode 100644 index 0000000000..cb7d0d1588 --- /dev/null +++ b/python/sgl_jax/srt/kernels/gdn/__init__.py @@ -0,0 +1,25 @@ +"""Gated DeltaNet (GDN) reference kernels. + +Public entry points: + +* :func:`ragged_gated_delta_rule_ref` — token-by-token ``lax.scan`` over a + packed ragged batch (extend / chunked-prefill). +* :func:`decode_gated_delta_rule_ref` — parallel single-step recurrence + across the batch (decode fast path). +* :func:`jax_causal_conv1d_prefill` / :func:`jax_causal_conv1d_update` — + depthwise causal conv1d helpers (ragged prefill + single-token decode). +""" + +from sgl_jax.srt.kernels.gdn.gated_delta import ( + decode_gated_delta_rule_ref, + jax_causal_conv1d_prefill, + jax_causal_conv1d_update, + ragged_gated_delta_rule_ref, +) + +__all__ = [ + "decode_gated_delta_rule_ref", + "jax_causal_conv1d_prefill", + "jax_causal_conv1d_update", + "ragged_gated_delta_rule_ref", +] diff --git a/python/sgl_jax/srt/kernels/gdn/gated_delta.py b/python/sgl_jax/srt/kernels/gdn/gated_delta.py new file mode 100644 index 0000000000..30f24b7cf0 --- /dev/null +++ b/python/sgl_jax/srt/kernels/gdn/gated_delta.py @@ -0,0 +1,522 @@ +"""Gated Delta-Rule primitives, kernels, and causal conv1d helpers for Qwen3-Next. + +The recurrence math matches HuggingFace +``torch_recurrent_gated_delta_rule`` (``transformers/models/qwen3_next/ +modeling_qwen3_next.py``): + + scale = 1 / sqrt(K) + q_t = q_t * scale + S_t = S_{t-1} * exp(g_t) + kv_mem_t = (S_t * k_t[..., None]).sum(axis=-2) + delta_t = (v_t - kv_mem_t) * beta_t + S_t = S_t + k_t[..., None] * delta_t[..., None, :] + o_t = (S_t * q_t[..., None]).sum(axis=-2) + +with ``S`` stored in ``[K, V]`` order per-head. + +Public entry points: + +Recurrence kernels — both take post-conv ``mixed_qkv`` plus the full +per-layer ``recurrent_state`` table and ``state_indices``, gather per-seq +state internally, and return per-request new state plus per-token output: + +* :func:`ragged_gated_delta_rule_ref` — token-by-token ``lax.scan`` over a + packed ragged batch (used in extend / chunked-prefill paths). +* :func:`decode_gated_delta_rule_ref` — single recurrence step parallelised + across the batch axis (used in the decode fast path; one token per + request, no scan). + +Conv1d helpers that front-run the delta rule: + +* :func:`jax_causal_conv1d_prefill` — depthwise causal conv1d over a + ragged-batched packed sequence; gathers per-seq prior state from a full + per-layer table. +* :func:`jax_causal_conv1d_update` — single-token causal conv1d update for + decode; takes per-request state directly. + +Internal helper :func:`_gated_delta_step` is leading-dim-agnostic and +shared between the two recurrence kernels. +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp + + +def _l2norm(x: jax.Array, eps: float = 1e-6) -> jax.Array: + norm = jnp.sqrt((x.astype(jnp.float32) ** 2).sum(axis=-1, keepdims=True) + eps) + return (x.astype(jnp.float32) / norm).astype(x.dtype) + + +def _gated_delta_step( + state: jax.Array, # [..., H, K, V] float32 + q_t: jax.Array, # [..., H, K] + k_t: jax.Array, # [..., H, K] + v_t: jax.Array, # [..., H, V] + g_t: jax.Array, # [..., H] log-decay + beta_t: jax.Array, # [..., H] +) -> tuple[jax.Array, jax.Array]: + """Single gated delta step. + + Leading-dim-agnostic — broadcasts over any prefix shape (e.g. ``[B, H]`` + for the dense scan, ``[H]`` for a per-token ragged scan). Returns + ``(new_state [..., H, K, V], out [..., H, V])``. + """ + decay = jnp.exp(g_t.astype(jnp.float32))[..., None, None] # [..., H, 1, 1] + state = state * decay + kv_mem = jnp.einsum("...hkv,...hk->...hv", state, k_t) + delta = (v_t - kv_mem) * beta_t[..., None] # [..., H, V] + # Outer product across K and V: k along K axis × delta along V axis. + state = state + k_t[..., None] * delta[..., None, :] # [..., H, K, V] + out = jnp.einsum("...hkv,...hk->...hv", state, q_t) + return state, out + + +# --------------------------------------------------------------------------- +# Causal conv1d (depthwise, kernel_size=K, stride=1, dilation=1) +# --------------------------------------------------------------------------- + + +def jax_causal_conv1d_prefill( + x: jax.Array, # [D, T] packed activations + weight: jax.Array, # [D, kernel_size] depthwise weight + bias: jax.Array | None = None, # [D] optional + cu_seqlens: jax.Array | None = None, # [B+1] + conv_state: jax.Array | None = None, # [num_blocks, D, kernel_size-1] full per-layer table + state_indices: jax.Array | None = None, # [B] req → slot + has_initial_state: jax.Array | None = None, # [B] bool + activation: str | None = None, +) -> tuple[jax.Array, jax.Array]: + """Depthwise causal conv1d over a ragged-batched packed sequence. + + Sequences are concatenated along the token axis. Boundaries are given by + ``cu_seqlens`` (``[0, len_0, len_0+len_1, ...]``). Each output position + only mixes inputs from its own request — boundary lookbacks are served + from ``conv_state`` (gathered via ``state_indices``) if provided, else + zero. + + ``has_initial_state``: ``[B]`` bool. ``True`` when the slot already + holds valid conv state from a previous chunk (chunked-prefill + continuation, prefix-cache hit, or running decode). ``False`` for + brand-new prefills, in which case the gathered slot is masked to + zero before use — the slot may hold stale state from a previously- + freed request and reading it would corrupt both the per-token + lookback AND the final-state left-pad for short requests. Mirrors + the same mask the recurrence ref applies. If ``None`` (the default), + behavior is "all True" (existing-state mode) to preserve backward + compatibility with callers that don't track this. + + Returns ``(y [D, T], new_conv_state)``. ``new_conv_state`` holds the + last ``K-1`` logical tokens of each request, scattered back into the + full pool table at ``state_indices`` — its shape is + ``[num_blocks, D, K-1]`` (the input ``conv_state``'s shape) and its + dtype matches the pool. When ``conv_state`` is ``None`` (test + fixture mode with no pool), ``new_conv_state`` falls back to the + per-request ``[B, D, K-1]`` slice. The scatter happens inside the + kernel so the same shape contract holds when this ref is later + replaced by a Pallas kernel that writes directly into pool buffers. + """ + if activation not in (None, "silu"): + raise ValueError(f"Unsupported causal conv1d activation: {activation}") + + D, T = x.shape + K = int(weight.shape[1]) + assert cu_seqlens is not None, "cu_seqlens is required" + B = int(cu_seqlens.shape[0]) - 1 + assert weight.shape == (D, K), f"weight {weight.shape} vs x {x.shape}" + assert (conv_state is None) == ( + state_indices is None + ), "conv_state and state_indices must be provided together" + + if conv_state is not None: + assert conv_state.shape[1:] == ( + D, + K - 1, + ), f"conv_state {conv_state.shape} channels/kernel != ({D}, {K - 1})" + assert state_indices.shape == ( + B, + ), f"state_indices {state_indices.shape} != expected ({B},)" + # Gather per-seq prior state once up front; later lookups index by + # local seq id rather than walking the full table per token. + state = conv_state[state_indices] # [B, D, K-1] + # Brand-new prefills (has_initial_state=False) must not read stale + # slot contents — the slot was previously owned by another request + # before allocation. Zero them out so both the per-token lookback + # gather AND the final-state left-pad treat them as fresh. + if has_initial_state is not None and K > 1: + assert has_initial_state.shape == ( + B, + ), f"has_initial_state {has_initial_state.shape} != expected ({B},)" + state = jnp.where( + has_initial_state[:, None, None], + state, + jnp.zeros_like(state), + ) + else: + state = None + + starts = cu_seqlens[:-1] # [B] inclusive + ends = cu_seqlens[1:] # [B] exclusive + seq_lens = ends - starts # [B] + + # Map each packed token index to its request id and intra-request position. + t_idx = jnp.arange(T) + seq_idx = jnp.searchsorted(cu_seqlens, t_idx, side="right") - 1 # [T] + pos = t_idx - starts[seq_idx] # [T] + + # Build the depthwise window. For each lookback o in [0, K-1] the source + # logical position is p' = pos[t] - o; in-request when p' >= 0, otherwise + # the lookback predates this batch and must come from the saved + # `conv_state`. The state holds the K-1 most-recent pre-batch tokens + # with newest at index K-2, so logical position p' (negative when + # pre-batch) maps to state slot (K-1) + p'. + o = jnp.arange(K) + src_t = t_idx[:, None] - o[None, :] # [T, K] + in_seq = src_t >= starts[seq_idx][:, None] # [T, K] + src_t_safe = jnp.clip(src_t, 0, T - 1) + x_gathered = x[:, src_t_safe] # [D, T, K] + + if state is not None and K > 1: + p_prime = pos[:, None] - o[None, :] # [T, K] + is_idx = jnp.clip((K - 1) + p_prime, 0, K - 2) # [T, K] + # Advanced indexing into [B, D, K-1] with two index arrays of shape + # [T, K] (seq_idx broadcast and is_idx) plus a full slice on D + # yields [T, K, D] (the slice axis trails the advanced ones per + # numpy rules). Transpose back to [D, T, K] to match `x_gathered`. + init_pulled = state[seq_idx[:, None], :, is_idx] # [T, K, D] + init_pulled = jnp.transpose(init_pulled, (2, 0, 1)) # [D, T, K] + x_gathered = jnp.where(in_seq[None], x_gathered, init_pulled) + elif K > 1: + x_gathered = jnp.where(in_seq[None], x_gathered, jnp.zeros_like(x_gathered)) + # K == 1: no lookback, `src_t == t_idx` and `in_seq` is all-True; no + # masking needed. + + # weight[d, K-1-o] is the coefficient for lookback o. + w_flipped = weight[:, ::-1].astype(x.dtype) # [D, K] + y = jnp.einsum("dtk,dk->dt", x_gathered, w_flipped) + if bias is not None: + y = y + bias.astype(x.dtype)[:, None] + if activation == "silu": + y = jax.nn.silu(y) + + # Final state: the K-1 most-recent logical tokens of each request. + # logical_idx[b, j] = (seq_lens[b] - (K-1)) + j, indexing into the per- + # request "logical token stream" (state-padding ++ in-batch tokens). + # When >= 0 the token came from x; when < 0 the token came from the + # prior conv_state (or zero pad). + if K > 1: + j = jnp.arange(K - 1)[None, :] # [1, K-1] + logical_idx = seq_lens[:, None] - (K - 1) + j # [B, K-1] + take_from_x = logical_idx >= 0 + src_t_end_safe = jnp.clip(starts[:, None] + logical_idx, 0, T - 1) + from_x = jnp.transpose(x[:, src_t_end_safe], (1, 0, 2)) # [B, D, K-1] + if state is not None: + is_slot = jnp.clip((K - 1) + logical_idx, 0, K - 2) # [B, K-1] + b_idx = jnp.arange(B)[:, None] + # Same advanced-indexing-with-slice trick as the per-token + # gather above: result is [B, K-1, D]; transpose to [B, D, K-1]. + from_init = state[b_idx, :, is_slot] # [B, K-1, D] + from_init = jnp.transpose(from_init, (0, 2, 1)) # [B, D, K-1] + final_state = jnp.where(take_from_x[:, None, :], from_x, from_init) + else: + final_state = jnp.where(take_from_x[:, None, :], from_x, jnp.zeros_like(from_x)) + else: + # K == 1: the conv has no left context, so the "state" is empty. + final_state = jnp.zeros((B, D, 0), dtype=x.dtype) + + # Scatter the per-request final state back into the full pool table. + # The fixture-only `conv_state is None` path returns the per-request + # slice as-is for tests that don't construct a pool. + if conv_state is not None: + new_conv_state = conv_state.at[state_indices].set(final_state.astype(conv_state.dtype)) + else: + new_conv_state = final_state + return y, new_conv_state + + +def jax_causal_conv1d_update( + x: jax.Array, # [B, D] one new token per batch element + conv_state: jax.Array, # [num_blocks, D, kernel_size-1] full per-layer table + state_indices: jax.Array, # [B] req → slot + weight: jax.Array, # [D, kernel_size] + bias: jax.Array | None = None, # [D] + activation: str | None = None, +) -> tuple[jax.Array, jax.Array]: + """Single-token causal conv1d update. + + State contract mirrors :func:`jax_causal_conv1d_prefill`: the caller + passes the full per-layer ``conv_state`` table plus ``state_indices``; + the per-request slice is gathered, updated, and scattered back inside + the kernel. Returns ``(y [B, D], new_conv_state)`` where + ``new_conv_state`` is the full pool table + ``[num_blocks, D, kernel_size-1]`` with the per-request slots + updated. Doing the scatter inside the kernel keeps the same shape + contract when this ref is later replaced by a Pallas kernel. + + No ``has_initial_state`` mask here: decode always runs after at least + one extend chunk, so every slot's conv state is valid by invariant + (same lifecycle argument as :func:`decode_gated_delta_rule_ref`). + Brand-new prefills are routed through :func:`jax_causal_conv1d_prefill`, + which *does* honor ``has_initial_state``. + """ + assert x.ndim == 2, f"x must be [B, D], got shape {x.shape}" + B, D = x.shape + kernel = int(weight.shape[1]) + assert conv_state.shape[1:] == ( + D, + kernel - 1, + ), f"conv_state {conv_state.shape} channels/kernel != ({D}, {kernel - 1})" + assert state_indices.shape == (B,), f"state_indices {state_indices.shape} != expected ({B},)" + + state = conv_state[state_indices] # [B, D, K-1] + # Rolling buffer: [state(kernel-1), x_new] → window of length kernel. + window = jnp.concatenate([state, x[..., None]], axis=-1) # [B, D, K] + y = jnp.einsum("bdk,dk->bd", window, weight.astype(x.dtype)) + if bias is not None: + y = y + bias.astype(x.dtype)[None, :] + if activation == "silu": + y = jax.nn.silu(y) + elif activation is None: + pass + else: + raise ValueError(f"Unsupported causal conv1d activation: {activation}") + new_state = window[..., 1:] # drop oldest + + # Scatter the per-request new state back into the full pool table. + new_conv_state = conv_state.at[state_indices].set(new_state.astype(conv_state.dtype)) + return y, new_conv_state + + +# --------------------------------------------------------------------------- +# Recurrence kernels (extend + decode reference implementations) +# --------------------------------------------------------------------------- + + +def ragged_gated_delta_rule_ref( + mixed_qkv: jax.Array, + b: jax.Array, + a: jax.Array, + recurrent_state: jax.Array, + A_log: jax.Array, + dt_bias: jax.Array, + cu_seqlens: jax.Array, + state_indices: jax.Array, + has_initial_state: jax.Array, + *, + n_kq: int, + n_v: int, + d_k: int, + d_v: int, +) -> tuple[jax.Array, jax.Array]: + """Ragged gated delta-rule forward (extend / chunked-prefill). + + Token-by-token ``jax.lax.scan`` over a packed ragged batch. Boundaries + are given by ``cu_seqlens``; ``cu_seqlens[-1]`` is the valid-token + count (tokens past it are padding and have their state writes gated + off by ``valid_mask``). + + Contract: the kernel both gathers from and scatters back into the + full per-layer recurrent-state table — callers receive an updated + full table, no per-request scatter step needed at the caller. This + matches what an eventual Pallas kernel would do at the HLO boundary + (kernels output into pool buffers directly) and mirrors + ``tpu_inference.layers.common.ragged_conv1d_jax``'s + gather-once / scatter-once shape (we avoid tpu-inference's + full-table scan-carry pattern because the scan can keep its small + ``[B, ...]`` working buffer; we only scatter at the end). + + Args: + mixed_qkv: Packed ``(Q | K | V)`` tokens of shape + ``[num_tokens, 2 * n_kq * d_k + n_v * d_v]``. Q/K are stored at + ``n_kq`` heads; expansion to ``n_v`` heads happens inside this + function (so callers should not pre-repeat). + b: Pre-sigmoid beta input, ``[num_tokens, n_v]``. + a: Pre-softplus delta-t input, ``[num_tokens, n_v]``. + recurrent_state: Full per-layer state table, + ``[num_blocks, n_v, d_k, d_v]``. + A_log: ``[n_v]`` log-A parameter. + dt_bias: ``[n_v]`` delta-t bias. + cu_seqlens: ``[B + 1]`` int32 cumulative sequence lengths in the + packed buffer; ``cu_seqlens[-1]`` is the number of valid + (non-padding) tokens. + state_indices: ``[B]`` int32 mapping request index to slot in + ``recurrent_state``. + has_initial_state: ``[B]`` bool. ``True`` when the slot already + holds a valid recurrent state (chunked-prefill continuation, + prefix-cache hit, or running decode); ``False`` for brand-new + prefills, which must start from zero regardless of stale slot + contents. Mirrors GPU's + ``initial_state[~has_initial_state, ...] = 0``. + n_kq: Number of key/query heads (per-shard). + n_v: Number of value heads (per-shard). Must be a multiple of n_kq. + d_k: Per-head key/query dim. + d_v: Per-head value dim. + + Returns: + ``(new_recurrent_state, output)`` where ``new_recurrent_state`` + is the full pool table ``[num_blocks, n_v, d_k, d_v]`` (same + dtype as the input ``recurrent_state``) with the per-request + slots updated in place, and ``output`` has shape + ``[num_tokens, n_v, d_v]`` in ``mixed_qkv.dtype``. + """ + num_tokens = mixed_qkv.shape[0] + B = state_indices.shape[0] + key_dim = n_kq * d_k + + # Slice + reshape + (optional) GQA expand ONCE outside the scan. Keeping + # these out of the per-token body keeps the sharding inference on stable + # ground: `query`/`key`/`value` arrive at the scan already shaped + # ``[T, n_v, d_k]`` / ``[T, n_v, d_v]`` with the tensor axis pinned to + # the head dim. Doing the reshape per step under explicit sharding lets + # JAX place ``"tensor"`` on the wrong axis when ``n_kq == 1`` (a 1-of-N + # split), which then breaks the outer-product step inside + # ``_gated_delta_step``. + query = mixed_qkv[..., :key_dim].reshape(num_tokens, n_kq, d_k) + key = mixed_qkv[..., key_dim : 2 * key_dim].reshape(num_tokens, n_kq, d_k) + value = mixed_qkv[..., 2 * key_dim :].reshape(num_tokens, n_v, d_v) + repeat_factor = n_v // n_kq + if repeat_factor > 1: + query = jnp.repeat(query, repeat_factor, axis=1) + key = jnp.repeat(key, repeat_factor, axis=1) + + last_valid_loc = cu_seqlens[-1] + token_idx = jnp.arange(num_tokens) + req_indices = jnp.searchsorted(cu_seqlens[1:], token_idx, side="right") + # Padding tokens (idx >= last_valid_loc) get clamped to a valid local seq + # id; their writes are gated off via `valid_mask` so the slot they "read" + # is irrelevant. + req_indices = jnp.clip(req_indices, 0, B - 1) + valid_mask = token_idx < last_valid_loc + + # Gather per-seq initial state once, then mask brand-new prefills to + # zero. Mirrors GPU's `initial_state[~has_initial_state, ...] = 0`. + init_state = recurrent_state[state_indices].astype(jnp.float32) + init_state = jnp.where( + has_initial_state[:, None, None, None], + init_state, + jnp.zeros_like(init_state), + ) + + A = jnp.exp(A_log.astype(jnp.float32)) + dt_bias_f32 = dt_bias.astype(jnp.float32) + scale = d_k**-0.5 + + def scan_fn(state_buf, xs): + # state_buf: [B, n_v, d_k, d_v] + q_h, k_h, v_h, b_t, a_t, req_idx, is_valid = xs + + state = state_buf[req_idx] # [n_v, d_k, d_v] + + # Cast to fp32 inside the kernel to match GPU's + # `fused_gdn_gating_kernel`. + q_h = _l2norm(q_h.astype(jnp.float32)) * scale + k_h = _l2norm(k_h.astype(jnp.float32)) + v_h = v_h.astype(jnp.float32) + beta = jax.nn.sigmoid(b_t.astype(jnp.float32)) + g = -A * jax.nn.softplus(a_t.astype(jnp.float32) + dt_bias_f32) + + # _gated_delta_step uses `...` and negative axes throughout, so it + # accepts any leading-dim shape — including no batch dim. + new_state, out = _gated_delta_step(state, q_h, k_h, v_h, g, beta) + + new_state_buf = jnp.where( + is_valid, + state_buf.at[req_idx].set(new_state), + state_buf, + ) + return new_state_buf, out.astype(mixed_qkv.dtype) + + new_state_buf, output = jax.lax.scan( + scan_fn, + init_state, + (query, key, value, b, a, req_indices, valid_mask), + ) + + # Scatter the per-request final states back into the full pool table + # in one shot. The scan body kept its working buffer small + # (``[B, n_v, d_k, d_v]``); only this one ``at[].set`` touches the + # ``[num_blocks, ...]`` table. Cast to the pool's dtype (the scan + # carries fp32; the pool is typically fp32 too, but + # ``SGLANG_JAX_RECURRENT_STATE_DTYPE=bfloat16`` can override). + new_recurrent_state = recurrent_state.at[state_indices].set( + new_state_buf.astype(recurrent_state.dtype) + ) + return new_recurrent_state, output + + +def decode_gated_delta_rule_ref( + mixed_qkv: jax.Array, + b: jax.Array, + a: jax.Array, + recurrent_state: jax.Array, + A_log: jax.Array, + dt_bias: jax.Array, + state_indices: jax.Array, + *, + n_kq: int, + n_v: int, + d_k: int, + d_v: int, +) -> tuple[jax.Array, jax.Array]: + """Decode-only gated delta-rule (parallel single-step across the batch). + + One token per request, no cross-token dependencies — so we run a + single ``_gated_delta_step`` parallelised across the batch axis + instead of feeding :func:`ragged_gated_delta_rule_ref` with + ``cu_seqlens=arange(B+1)`` (which would serialise B independent steps + as a ``T=B`` scan). Numerically equivalent to that path; just faster. + + Decode always runs after at least one extend, so every slot already + holds valid recurrent state — there is no ``has_initial_state`` mask + here (the equivalent argument would always be all-``True``). + + Args: + mixed_qkv: Post-conv tokens of shape + ``[B, 2 * n_kq * d_k + n_v * d_v]`` (one token per request). + b: Pre-sigmoid beta input, ``[B, n_v]``. + a: Pre-softplus delta-t input, ``[B, n_v]``. + recurrent_state: Full per-layer state table, + ``[num_blocks, n_v, d_k, d_v]``. + A_log: ``[n_v]`` log-A parameter. + dt_bias: ``[n_v]`` delta-t bias. + state_indices: ``[B]`` int32 mapping request index to slot. + n_kq, n_v, d_k, d_v: head/dim configuration (see + :func:`ragged_gated_delta_rule_ref`). + + Returns: + ``(new_recurrent_state, output)`` where ``new_recurrent_state`` + is the full pool table ``[num_blocks, n_v, d_k, d_v]`` (same + dtype as the input ``recurrent_state``) with the per-request + slots updated in place, and ``output`` has shape + ``[B, n_v, d_v]`` in ``mixed_qkv.dtype``. + """ + B = mixed_qkv.shape[0] + key_dim = n_kq * d_k + q = mixed_qkv[:, :key_dim].reshape(B, n_kq, d_k) + k = mixed_qkv[:, key_dim : 2 * key_dim].reshape(B, n_kq, d_k) + v = mixed_qkv[:, 2 * key_dim :].reshape(B, n_v, d_v) + repeat_factor = n_v // n_kq + if repeat_factor > 1: + q = jnp.repeat(q, repeat_factor, axis=1) + k = jnp.repeat(k, repeat_factor, axis=1) + + # Cast to fp32 inside the kernel to match GPU's + # `fused_gdn_gating_kernel` and the ragged-ref path. + scale = d_k**-0.5 + q_h = _l2norm(q.astype(jnp.float32)) * scale + k_h = _l2norm(k.astype(jnp.float32)) + v_h = v.astype(jnp.float32) + A = jnp.exp(A_log.astype(jnp.float32)) + dt_bias_f32 = dt_bias.astype(jnp.float32) + beta = jax.nn.sigmoid(b.astype(jnp.float32)) + g = -A * jax.nn.softplus(a.astype(jnp.float32) + dt_bias_f32) + + state = recurrent_state[state_indices].astype(jnp.float32) + new_state, out = _gated_delta_step(state, q_h, k_h, v_h, g, beta) + + # Scatter the per-request new state back into the full pool table. + new_recurrent_state = recurrent_state.at[state_indices].set( + new_state.astype(recurrent_state.dtype) + ) + return new_recurrent_state, out.astype(mixed_qkv.dtype) diff --git a/python/sgl_jax/srt/layers/attention/linear/gdn_backend.py b/python/sgl_jax/srt/layers/attention/linear/gdn_backend.py new file mode 100644 index 0000000000..fe26d7d276 --- /dev/null +++ b/python/sgl_jax/srt/layers/attention/linear/gdn_backend.py @@ -0,0 +1,377 @@ +"""Gated-DeltaNet attention backend. + +Inherits :class:`LinearRecurrentAttnBackend` for shared metadata +(``cu_q_lens`` / ``recurrent_indices`` / ``has_initial_state``) and +pytree boilerplate. Owns the (fused) conv1d weight + delta-rule params +(``A_log``, ``dt_bias``); the parent layer hands in ``mixed_qkv`` +(a per-device block-concat ``[Q | K | V]`` of size ``conv_dim`` +channels) plus ``b``, ``a``, and a :class:`RecurrentStatePool`. State +(conv + recurrent) is fetched from the pool internally via the base +class's :meth:`get_layer_cache` helper. + +Sharding pattern: the conv + recurrence pipeline runs inside +:func:`jax.shard_map` with explicit ``in_specs`` / ``out_specs``, with +the head axis pinned to ``"tensor"`` so each device sees only its local +shard. The kernels then operate on per-shard head counts +(``n_kq // TP``, ``n_v // TP``) without relying on JAX sharding +inference. Returns ``(core_attn_out, new_conv, new_rec)`` shaped for +``RecurrentStatePool.write_layer``. +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from flax import nnx +from jax.sharding import PartitionSpec as P + +from sgl_jax.srt.kernels.gdn import ( + decode_gated_delta_rule_ref, + jax_causal_conv1d_prefill, + jax_causal_conv1d_update, + ragged_gated_delta_rule_ref, +) +from sgl_jax.srt.layers.attention.hybrid_linear_attn_backend import ( + LinearRecurrentAttnBackend, +) +from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch + + +def _mesh_tp_size(mesh: jax.sharding.Mesh) -> int: + """TP size = mesh size on the ``"tensor"`` axis (1 if absent).""" + if mesh is None: + return 1 + shape = getattr(mesh, "shape", None) + if shape is None or "tensor" not in shape: + return 1 + return int(shape["tensor"]) + + +class GDNAttnBackend(LinearRecurrentAttnBackend): + """Gated-DeltaNet attention backend. + + Owns the conv1d weight + delta-rule params; dispatches conv1d + ragged + delta-rule (extend) or single-step delta-rule (decode) under + ``jax.shard_map``. Reads ``cu_q_lens`` / ``recurrent_indices`` / + ``has_initial_state`` from ``self.forward_metadata``, populated by + the base class's :meth:`get_forward_metadata` before each forward. + """ + + def __init__( + self, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_kernel_size: int, + mesh: jax.sharding.Mesh, + dtype: jnp.dtype = jnp.bfloat16, + ): + super().__init__(mesh=mesh) + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.conv_kernel_size = conv_kernel_size + + self.key_dim = num_k_heads * head_k_dim + self.value_dim = num_v_heads * head_v_dim + self.conv_dim = 2 * self.key_dim + self.value_dim + + # Per-shard slicing in the kernels uses `num_*_heads // tp` (integer + # division). `MergedColumnParallelLinear` only checks that `key_dim` + # and `value_dim` are divisible by TP, which is not enough: e.g. + # `num_k_heads=1, head_k_dim=128, TP=2` gives `key_dim=128` (divisible) + # but `num_k_heads // TP = 0`, and the per-shard reshape silently + # produces zero-head arrays. GQA also relies on `num_v_heads % + # num_k_heads == 0` so the per-step `jnp.repeat` produces exactly + # `num_v_heads` heads. + tp = _mesh_tp_size(mesh) + if num_k_heads % tp != 0: + raise ValueError( + f"GDNAttnBackend: num_k_heads={num_k_heads} must be divisible " f"by TP={tp}." + ) + if num_v_heads % tp != 0: + raise ValueError( + f"GDNAttnBackend: num_v_heads={num_v_heads} must be divisible " f"by TP={tp}." + ) + if self.conv_dim % tp != 0: + raise ValueError( + f"GDNAttnBackend: conv_dim={self.conv_dim} must be divisible " + f"by TP={tp} for clean per-shard channel slicing." + ) + if num_v_heads % num_k_heads != 0: + raise ValueError( + f"GDNAttnBackend: num_v_heads={num_v_heads} must be a multiple " + f"of num_k_heads={num_k_heads} (GQA repeat factor)." + ) + + # Depthwise conv1d weight (HF stores [conv_dim, 1, K]; we squeeze). + # Sharded on the conv_dim axis so each TP rank owns its own channels + # — consistent with how `RecurrentStatePool` shards conv_state. + # + # IMPORTANT — per-shard channel layout is a *loader contract*, not a + # property of this Param. The shard_map calls below use + # ``in_specs=P("tensor", None)``, which slices axis 0 into TP + # contiguous chunks. For the conv1d to line up with ``mixed_qkv``, + # each rank's local rows must be the per-shard block-concat + # ``[q_tp | k_tp | v_tp]`` (the same convention + # :class:`MergedColumnParallelLinear` produces for ``in_proj_qkv``). + # The HF checkpoint stores conv1d as a single + # ``[global_q | global_k | global_v]`` block along ``conv_dim`` — a + # naive ``device_put`` with ``P("tensor", None)`` would give rank 0 + # mostly Q channels and rank N-1 mostly V, which silently mismatches + # the per-shard activation layout and produces wrong outputs at + # TP > 1 with no crash. + # + # The model loader (CUDA sglang reference: + # ``mamba_v2_sharded_weight_loader`` in + # ``sglang/srt/layers/attention/mamba/mamba.py``) must therefore + # stripe-rearrange the HF tensor so each rank's local rows are + # ``[Q[rank * key_dim/TP : (rank+1) * key_dim/TP] + # | K[rank * key_dim/TP : (rank+1) * key_dim/TP] + # | V[rank * value_dim/TP : (rank+1) * value_dim/TP]]`` + # before placement. ``in_proj_qkv`` follows the same convention. + # + # A TP > 1 numerical test against an fp32 reference is the canary + # for getting this wrong — at TP = 1 the two layouts coincide and + # bugs hide. + self.conv1d_weight = nnx.Param(jnp.zeros((self.conv_dim, conv_kernel_size), dtype=dtype)) + # Delta-rule params, sharded per-head. Storage dtypes follow the HF + # Qwen3.5 checkpoint exactly: + # A_log: fp32 (the recurrence's ``-exp(A_log)`` factor is + # numerically sensitive — checkpoint is fp32 and the + # gating kernel reads it as such). + # dt_bias: model dtype (bf16). The gating kernel upcasts to fp32 + # internally for ``softplus(a + dt_bias)``; storing fp32 + # here would only force a load-time cast and double the + # param footprint with no numerical benefit. + self.A_log = nnx.Param(jnp.zeros((num_v_heads,), dtype=jnp.float32)) + self.dt_bias = nnx.Param(jnp.ones((num_v_heads,), dtype=dtype)) + + # ------------------------------------------------------------------ + # Dispatch + # ------------------------------------------------------------------ + + def __call__( + self, + forward_batch: ForwardBatch, + mixed_qkv: jax.Array, # [T, conv_dim] (None, "tensor") + b: jax.Array, # [T, n_v] (None, "tensor") + a: jax.Array, # [T, n_v] (None, "tensor") + recurrent_state_pool, + layer_id: int, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + """Dispatch by ``forward_batch.forward_mode``. + + Fetches per-layer ``(recurrent_state, conv_state)`` from the pool + via the base class's :meth:`get_layer_cache`. ``conv_state`` is the + first (only) entry of the per-layer conv-state list — GDN uses a + single fused conv1d, so it needs exactly one conv buffer per layer + (vs. KDA, which keeps q/k/v conv states as three list entries). + + Returns ``(core_attn_out, new_conv_state, new_rec_state)`` where + ``new_conv_state`` and ``new_rec_state`` are the full pool tables + with this layer's per-request slots updated (scatter happens + inside the kernel — see :func:`ragged_gated_delta_rule_ref` / + :func:`jax_causal_conv1d_prefill` and the decode-path equivalents). + Caller writes these back onto the pool (e.g. via + ``RecurrentStatePool.replace_buffer``). + """ + recurrent_state, conv_states = self.get_layer_cache(recurrent_state_pool, layer_id) + conv_state = conv_states[0] + + if forward_batch.forward_mode.is_decode(): + return self.forward_decode(mixed_qkv, conv_state, recurrent_state, b, a) + return self.forward_extend(mixed_qkv, conv_state, recurrent_state, b, a) + + # ------------------------------------------------------------------ + # Decode fast path + # ------------------------------------------------------------------ + + def forward_decode( + self, + mixed_qkv: jax.Array, + conv_state_in: jax.Array, + recurrent_state_in: jax.Array, + b: jax.Array, + a: jax.Array, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + """One token per request — single conv1d update + parallel single + recurrence step across the batch, all inside a shard_map.""" + state_indices = self.forward_metadata.recurrent_indices + tp = _mesh_tp_size(self.mesh) + n_kq_tp = self.num_k_heads // tp + n_v_tp = self.num_v_heads // tp + d_k = self.head_k_dim + d_v = self.head_v_dim + + def _decode_local( + mixed_qkv_l, + conv_state_l, + rec_state_l, + conv_weight_l, + A_log_l, + dt_bias_l, + b_l, + a_l, + state_indices_l, + ): + conv_out, new_conv = jax_causal_conv1d_update( + mixed_qkv_l, + conv_state_l, + state_indices_l, + conv_weight_l, + bias=None, + activation="silu", + ) + new_rec, out = decode_gated_delta_rule_ref( + conv_out, + b_l, + a_l, + rec_state_l, + A_log_l, + dt_bias_l, + state_indices_l, + n_kq=n_kq_tp, + n_v=n_v_tp, + d_k=d_k, + d_v=d_v, + ) + return out, new_conv, new_rec + + return jax.shard_map( + _decode_local, + mesh=self.mesh, + in_specs=( + P(None, "tensor"), # mixed_qkv + P(None, "tensor", None), # conv_state + P(None, "tensor", None, None), # recurrent_state + P("tensor", None), # conv1d weight + P("tensor"), # A_log + P("tensor"), # dt_bias + P(None, "tensor"), # b + P(None, "tensor"), # a + P(), # state_indices (replicated) + ), + out_specs=( + P(None, "tensor", None), # out [B, n_v, d_v] + P(None, "tensor", None), # new_conv_state [num_blocks, conv_dim, K-1] + P(None, "tensor", None, None), # new_rec_state [num_blocks, n_v, d_k, d_v] + ), + check_vma=False, + )( + mixed_qkv, + conv_state_in, + recurrent_state_in, + self.conv1d_weight.value, + self.A_log.value, + self.dt_bias.value, + b, + a, + state_indices, + ) + + # ------------------------------------------------------------------ + # Extend / chunked-prefill + # ------------------------------------------------------------------ + + def forward_extend( + self, + mixed_qkv: jax.Array, + conv_state_in: jax.Array, + recurrent_state_in: jax.Array, + b: jax.Array, + a: jax.Array, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + """Packed ragged batch through ``ragged_gated_delta_rule_ref``.""" + meta = self.forward_metadata + cu_seqlens = meta.cu_q_lens + state_indices = meta.recurrent_indices + has_initial_state = meta.has_initial_state + tp = _mesh_tp_size(self.mesh) + n_kq_tp = self.num_k_heads // tp + n_v_tp = self.num_v_heads // tp + d_k = self.head_k_dim + d_v = self.head_v_dim + + def _extend_local( + mixed_qkv_l, + conv_state_l, + rec_state_l, + conv_weight_l, + A_log_l, + dt_bias_l, + b_l, + a_l, + cu_seqlens_l, + state_indices_l, + has_initial_state_l, + ): + # jax_causal_conv1d_prefill operates on [D, T] (channel-first). + # Pass `has_initial_state` so brand-new prefills don't pick up + # stale conv state from a freshly-allocated slot (same mask + # contract as `ragged_gated_delta_rule_ref`). + conv_out_dt, new_conv = jax_causal_conv1d_prefill( + x=mixed_qkv_l.T, + weight=conv_weight_l, + bias=None, + cu_seqlens=cu_seqlens_l, + conv_state=conv_state_l, + state_indices=state_indices_l, + has_initial_state=has_initial_state_l, + activation="silu", + ) + conv_out = conv_out_dt.T # [T, D] + new_rec, out = ragged_gated_delta_rule_ref( + conv_out, + b_l, + a_l, + rec_state_l, + A_log_l, + dt_bias_l, + cu_seqlens=cu_seqlens_l, + state_indices=state_indices_l, + has_initial_state=has_initial_state_l, + n_kq=n_kq_tp, + n_v=n_v_tp, + d_k=d_k, + d_v=d_v, + ) + return out, new_conv, new_rec + + return jax.shard_map( + _extend_local, + mesh=self.mesh, + in_specs=( + P(None, "tensor"), # mixed_qkv + P(None, "tensor", None), # conv_state + P(None, "tensor", None, None), # recurrent_state + P("tensor", None), # conv1d weight + P("tensor"), # A_log + P("tensor"), # dt_bias + P(None, "tensor"), # b + P(None, "tensor"), # a + P(), # cu_seqlens (replicated) + P(), # state_indices (replicated) + P(), # has_initial_state (replicated) + ), + out_specs=( + P(None, "tensor", None), # out [T, n_v, d_v] + P(None, "tensor", None), # new_conv_state [num_blocks, conv_dim, K-1] + P(None, "tensor", None, None), # new_rec_state [num_blocks, n_v, d_k, d_v] + ), + check_vma=False, + )( + mixed_qkv, + conv_state_in, + recurrent_state_in, + self.conv1d_weight.value, + self.A_log.value, + self.dt_bias.value, + b, + a, + cu_seqlens, + state_indices, + has_initial_state, + ) diff --git a/python/sgl_jax/srt/layers/attention/linear/qwen3_5_gated_delta_net.py b/python/sgl_jax/srt/layers/attention/linear/qwen3_5_gated_delta_net.py new file mode 100644 index 0000000000..4c0c4f2e1f --- /dev/null +++ b/python/sgl_jax/srt/layers/attention/linear/qwen3_5_gated_delta_net.py @@ -0,0 +1,254 @@ +"""Qwen3-5 Gated DeltaNet linear-attention layer. + +Shape-correct port of HuggingFace's ``Qwen3_5GatedDeltaNet``. The HF +checkpoint has four separate projections (a key difference from +Qwen3-Next, which fuses everything into ``in_proj_qkvz``):: + + in_proj_qkv: [hidden, 2*key_dim + value_dim] ([Q | K | V] block-concat) + in_proj_z: [hidden, value_dim] + in_proj_b: [hidden, num_v_heads] + in_proj_a: [hidden, num_v_heads] + conv1d.weight: [conv_dim, kernel_size] + A_log: [num_v_heads] + dt_bias: [num_v_heads] + norm.weight: [head_v_dim] + out_proj: [value_dim, hidden] + +We use :class:`MergedColumnParallelLinear` for ``in_proj_qkv`` so the +single GEMM uses TPU's MXU better than three smaller ones. Q/K/V are +declared as three components with sizes ``[key_dim, key_dim, value_dim]``; +each component shards on its own head dim, and the per-device output is +already laid out as block-concat ``[q_tp | k_tp | v_tp]`` of size +``conv_dim/TP`` — exactly what the conv1d + recurrence pipeline want. +``z`` / ``b`` / ``a`` stay on their own ``LinearBase``s (small projections +where fusing wouldn't help much, and HF stores them separately anyway). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import jax +import jax.numpy as jnp +from flax import nnx +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +from sgl_jax.srt.layers.attention.linear.gdn_backend import GDNAttnBackend +from sgl_jax.srt.layers.linear import LinearBase, MergedColumnParallelLinear + +if TYPE_CHECKING: + from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch + + +class Qwen3_5GatedDeltaNet(nnx.Module): + """Qwen3-5 Gated DeltaNet linear-attention layer. + + ``config`` exposes ``hidden_size``, ``linear_num_value_heads``, + ``linear_num_key_heads``, ``linear_key_head_dim``, + ``linear_value_head_dim``, ``linear_conv_kernel_dim``, and + ``rms_norm_eps`` (matches HF's ``Qwen3_5TextConfig``). + """ + + def __init__( + self, + config: Any, + layer_id: int, + mamba_layer_id: int, + mesh: jax.sharding.Mesh, + dtype: jnp.dtype = jnp.bfloat16, + ): + self.layer_id = layer_id + self.mamba_layer_id = mamba_layer_id + self.mesh = mesh + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.conv_kernel_size = config.linear_conv_kernel_dim + self.eps = config.rms_norm_eps + + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_dim = 2 * self.key_dim + self.value_dim + + # Fused Q/K/V projection: one big GEMM with per-shard block-concat + # layout `[q_tp | k_tp | v_tp]` of size `conv_dim/TP` columns per + # device. Components shard independently on their own head dim, so + # GQA (where Q/K and V have different per-head sizes) doesn't cut + # a shard mid-component. + self.in_proj_qkv = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.key_dim, self.key_dim, self.value_dim], + use_bias=False, + params_dtype=dtype, + mesh=mesh, + scope_name="in_proj_qkv", + ) + # z / b / a remain separate: small projections, and HF stores them + # as independent tensors. Each is column-parallel on its own head dim. + self.in_proj_z = LinearBase( + input_size=self.hidden_size, + output_size=self.value_dim, + use_bias=False, + kernel_axes=(None, "tensor"), + params_dtype=dtype, + mesh=mesh, + scope_name="in_proj_z", + ) + self.in_proj_b = LinearBase( + input_size=self.hidden_size, + output_size=self.num_v_heads, + use_bias=False, + kernel_axes=(None, "tensor"), + params_dtype=dtype, + mesh=mesh, + scope_name="in_proj_b", + ) + self.in_proj_a = LinearBase( + input_size=self.hidden_size, + output_size=self.num_v_heads, + use_bias=False, + kernel_axes=(None, "tensor"), + params_dtype=dtype, + mesh=mesh, + scope_name="in_proj_a", + ) + + # Backend owns conv1d_weight + A_log + dt_bias and runs the + # conv+recurrence under shard_map. + # + # TODO(post-qwen3.5+kimi-linear): wrap this through + # ``RadixLinearAttention`` so all linear-attention layers in the + # repo share the same dispatch shape (matches the sglang attention + # backend pattern). The interface contract today is shaped around + # KDA's split-stream layout (three separate + # ``q_conv1d`` / ``k_conv1d`` / ``v_conv1d`` ``LinearBase`` + # containers, ``(q, k, v, a, b)`` call signature), while GDN runs + # one fused conv1d on ``mixed_qkv``. The reviewer flagged that + # the unified shape should be designed after both qwen3.5 and + # kimi-linear land, with both use cases visible — keeping the + # fused-vs-split decision open until then. + self.attention = GDNAttnBackend( + num_k_heads=self.num_k_heads, + num_v_heads=self.num_v_heads, + head_k_dim=self.head_k_dim, + head_v_dim=self.head_v_dim, + conv_kernel_size=self.conv_kernel_size, + mesh=mesh, + dtype=dtype, + ) + + # Gated GemmaRMSNorm per-head along head_v_dim. + self.rms_scale = nnx.Param(jnp.ones((self.head_v_dim,), dtype=jnp.float32)) + + # Row-parallel output projection (all-reduce across "tensor"). + self.out_proj = LinearBase( + input_size=self.value_dim, + output_size=self.hidden_size, + use_bias=False, + kernel_axes=("tensor", None), + params_dtype=dtype, + mesh=mesh, + scope_name="out_proj", + ) + + # ----- helpers ---------------------------------------------------------- + + def _rms_gate(self, core_attn_out: jax.Array, z: jax.Array) -> jax.Array: + """``Qwen3NextRMSNormGated``: ``rmsnorm(core) · γ * silu(z)``. + + Both inputs are ``[T, num_v_heads, head_v_dim]``. Order matches HF's + :class:`Qwen3NextRMSNormGated.forward` exactly so numerical-equivalence + tests against the HF reference reproduce bit-for-bit:: + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * rsqrt(variance + eps) # fp32 + hidden_states = self.weight * hidden_states.to(input_dtype) # round + promote + hidden_states = hidden_states * F.silu(gate.to(float32)) # fp32 + return hidden_states.to(input_dtype) + + The non-obvious step is the round-trip cast to ``input_dtype`` + between the norm and the gamma-multiply: HF quantizes the normalized + activations to bf16 first, then the ``fp32 * bf16`` multiply with + ``γ`` (stored fp32) promotes back to fp32. Skipping that + intermediate cast — what the previous version did — kept everything + in fp32 and produced slightly more accurate values that didn't + match HF in the last bf16 ulp. + """ + input_dtype = core_attn_out.dtype + x = core_attn_out.astype(jnp.float32) + rms = jnp.sqrt((x * x).mean(axis=-1, keepdims=True) + self.eps) + x = x / rms + # Round to input_dtype before gamma-multiply; promote the product + # back to fp32 via the fp32-stored weight so the silu-gate step + # runs in fp32 (matches HF's `self.weight * hidden_states.to(input_dtype)`). + gamma_f32 = self.rms_scale.value.astype(jnp.float32) + x = gamma_f32 * x.astype(input_dtype).astype(jnp.float32) + gated = x * jax.nn.silu(z.astype(jnp.float32)) + return gated.astype(input_dtype) + + # ----- forward ---------------------------------------------------------- + + def __call__( + self, + hidden_states: jax.Array, # [T, hidden] + forward_batch: ForwardBatch, + recurrent_state_pool, + ) -> tuple[jax.Array, jax.Array, jax.Array]: + """Returns ``(output [T, hidden], new_conv [B, conv_dim, K-1], new_rec [B, H, K, V])``. + + The backend fetches this layer's ``(recurrent_state, conv_state)`` + from ``recurrent_state_pool`` via its base class's + :meth:`get_layer_cache` helper (keyed on ``self.layer_id``) and + returns per-request new states ready for + ``RecurrentStatePool.write_layer``. + + Donation contract: ``recurrent_state_pool`` is read inside the + backend and then only the per-request new slices are emitted. The + outer jitted forward step should mark the pool buffers as + ``donate_argnames=`` (or ``donate_argnums=``) on its ``jax.jit`` + so XLA can reuse their HBM for the next step's pool. Per-layer + state copies across dozens of GDN layers are the dominant per-step + HBM traffic on large models; without donation each layer pays a + full pool copy. The caller must guarantee it does not read the + donated pool buffers after the forward step returns, or JAX will + raise ``Donated buffer has been deleted``. + """ + T = hidden_states.shape[0] + + # Fused Q/K/V via a single GEMM. Per-device output is per-shard + # block-concat `[q_tp | k_tp | v_tp]`, exactly what conv1d wants. + mixed_qkv, _ = self.in_proj_qkv(hidden_states) # [T, conv_dim] + z, _ = self.in_proj_z(hidden_states) # [T, value_dim] + b, _ = self.in_proj_b(hidden_states) # [T, num_v_heads] + a, _ = self.in_proj_a(hidden_states) # [T, num_v_heads] + + # Reshape z for the post-recurrence gate; sharding stays on n_v. + z = jax.lax.reshape( + z, + (T, self.num_v_heads, self.head_v_dim), + out_sharding=NamedSharding(self.mesh, P(None, "tensor", None)), + ) + + core_attn_out, new_conv, new_rec = self.attention( + forward_batch, + mixed_qkv, + b, + a, + recurrent_state_pool, + self.layer_id, + ) + # core_attn_out: [T, num_v_heads, head_v_dim] + + gated = self._rms_gate(core_attn_out, z) + gated_flat = jax.lax.reshape( + gated, + (T, self.num_v_heads * self.head_v_dim), + out_sharding=NamedSharding(self.mesh, P(None, "tensor")), + ) + output, _ = self.out_proj(gated_flat) + return output, new_conv, new_rec diff --git a/python/sgl_jax/srt/layers/linear.py b/python/sgl_jax/srt/layers/linear.py index 35e9e35e69..c49641f5c9 100644 --- a/python/sgl_jax/srt/layers/linear.py +++ b/python/sgl_jax/srt/layers/linear.py @@ -86,6 +86,88 @@ def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]: return out, None +class MergedColumnParallelLinear(LinearBase): + """Column-parallel linear with multiple logical outputs merged into one weight. + + Equivalent to ``N`` independent column-parallel ``LinearBase``s with the + same ``input_size`` but different ``output_size``, fused into one larger + GEMM. A single large matmul on TPU's MXU is consistently faster than ``N`` + smaller ones — fewer kernel launches, better pipelining of weight reads, + and a single MXU pass amortizes the input-side broadcast. + + Sharding contract (mirrors sglang / vLLM's ``MergedColumnParallelLinear``): + each device's local weight columns hold + ``[comp_0_my_heads | comp_1_my_heads | ...]`` block-concat. Splitting the + merged output into per-component pieces must therefore happen on + per-device data (typically inside :func:`jax.shard_map`) using **per-shard** + sizes — the global merged tensor is stripe-interleaved across devices, + not a true ``[comp_0 | comp_1 | ...]`` block-concat. + + Each entry of ``output_sizes`` must be divisible by the mesh's ``"tensor"`` + axis size so the per-shard block-concat boundary aligns with the TP cut. + Without this, GQA-style projections (where components have different + head counts) would put a shard boundary mid-component — exactly the + failure mode this layer exists to avoid. + + Weight loading is the caller's responsibility — there's no built-in + loader yet because the simple host-side scatter (collect HF tensors, + stripe them per-rank on host, single ``device_put``) costs N host + buffers and a full-tensor staging copy. A device-side scatter + (writing each HF tensor into a sharded merged param via + ``jax.lax.dynamic_update_slice`` under the right sharding context) + is the right shape for production but needs more design — left as a + follow-up. + + Args: + input_size: Input dimension. + output_sizes: Per-component output dimensions. Must each be + divisible by the mesh's ``"tensor"`` axis size. + mesh: Device mesh (must expose a ``"tensor"`` axis for sharding; + falls back to TP=1 if absent or ``mesh is None``). + use_bias / skip_bias_add / params_dtype: forwarded to ``LinearBase``. + scope_name: profiling scope. + """ + + @staticmethod + def _mesh_tp_size(mesh: jax.sharding.Mesh | None) -> int: + """TP size = mesh size on the ``"tensor"`` axis (1 if absent).""" + if mesh is None: + return 1 + shape = getattr(mesh, "shape", None) + if shape is None or "tensor" not in shape: + return 1 + return int(shape["tensor"]) + + def __init__( + self, + input_size: int, + output_sizes: Sequence[int], + mesh: jax.sharding.Mesh, + use_bias: bool = False, + skip_bias_add: bool = False, + params_dtype: jnp.dtype | None = jnp.bfloat16, + scope_name: str = "merged_column_parallel_linear", + ): + self.output_sizes = list(output_sizes) + tp_size = self._mesh_tp_size(mesh) + for i, sz in enumerate(self.output_sizes): + if sz % tp_size != 0: + raise ValueError( + f"MergedColumnParallelLinear: output_sizes[{i}]={sz} must be " + f"divisible by TP={tp_size} for clean per-shard block-concat layout." + ) + super().__init__( + input_size=input_size, + output_size=sum(self.output_sizes), + mesh=mesh, + use_bias=use_bias, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + kernel_axes=(None, "tensor"), + scope_name=scope_name, + ) + + class QuantizedLinear(nnx.Module): """Quantized linear layer using native quantized matmul. diff --git a/python/sgl_jax/test/kernels/gdn/test_gated_delta.py b/python/sgl_jax/test/kernels/gdn/test_gated_delta.py new file mode 100644 index 0000000000..09bacae340 --- /dev/null +++ b/python/sgl_jax/test/kernels/gdn/test_gated_delta.py @@ -0,0 +1,632 @@ +"""Unit tests for the primitives + kernels in :mod:`gated_delta`. + +Covers ``_l2norm``, ``_gated_delta_step``, ``jax_causal_conv1d_prefill``, +``jax_causal_conv1d_update``, and ``decode_gated_delta_rule_ref``. The +ragged kernel is covered separately in ``test_ragged_gated_delta_rule_ref.py``. + +Run with: + JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \\ + python -m pytest test/srt/test_gated_delta.py -v +""" + +from __future__ import annotations + +import os +import unittest + +os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8") + +import jax +import jax.numpy as jnp +import numpy as np + +from sgl_jax.srt.kernels.gdn import ( + decode_gated_delta_rule_ref, + jax_causal_conv1d_prefill, + jax_causal_conv1d_update, + ragged_gated_delta_rule_ref, +) +from sgl_jax.srt.kernels.gdn.gated_delta import _gated_delta_step, _l2norm + + +class L2NormTest(unittest.TestCase): + def test_unit_vectors(self): + """3-4-5 triangle: ||(3,4)|| = 5, normalized = (0.6, 0.8).""" + x = jnp.array([[3.0, 4.0]], dtype=jnp.float32) + y = _l2norm(x) + np.testing.assert_allclose(y, [[0.6, 0.8]], atol=1e-6) + + def test_normalizes_along_last_axis_only(self): + """Each row gets its own norm; 2D shape preserved.""" + x = jnp.array([[1.0, 0.0], [0.0, 5.0], [3.0, 4.0]], dtype=jnp.float32) + y = _l2norm(x) + # Each row's L2 norm is ~1 (not ||x.flatten()||). + np.testing.assert_allclose(jnp.linalg.norm(y, axis=-1), [1.0, 1.0, 1.0], atol=1e-5) + + def test_zero_vector_eps_safe(self): + """Eps prevents NaN on the all-zeros input.""" + y = _l2norm(jnp.zeros((1, 4), dtype=jnp.float32)) + self.assertTrue(bool(jnp.all(jnp.isfinite(y)))) + + +class GatedDeltaStepTest(unittest.TestCase): + """Math sanity for the leading-dim-agnostic recurrence primitive.""" + + def test_zero_state_no_decay_full_beta(self): + """With S=0, β=1, g=0: new_state = k⊗v, out = (q·k)·v.""" + H, K, V = 2, 3, 4 + rng = jax.random.split(jax.random.key(0), 3) + q = jax.random.normal(rng[0], (H, K), dtype=jnp.float32) + k = jax.random.normal(rng[1], (H, K), dtype=jnp.float32) + v = jax.random.normal(rng[2], (H, V), dtype=jnp.float32) + g = jnp.zeros((H,), dtype=jnp.float32) + beta = jnp.ones((H,), dtype=jnp.float32) + state = jnp.zeros((H, K, V), dtype=jnp.float32) + + new_state, out = _gated_delta_step(state, q, k, v, g, beta) + + # k ⊗ v + expected_state = k[..., None] * v[..., None, :] + np.testing.assert_allclose(new_state, expected_state, atol=1e-5) + + # out = (state @ q) summed over K with new state + # = ((q·k) · v) per head + qk = (q * k).sum(axis=-1, keepdims=True) # [H, 1] + expected_out = qk * v + np.testing.assert_allclose(out, expected_out, atol=1e-5) + + def test_full_decay_drops_state(self): + """g = -inf (decay = 0) means prior state is forgotten — only the + new k⊗(β·v) survives.""" + H, K, V = 1, 2, 3 + rng = jax.random.split(jax.random.key(1), 4) + q = jax.random.normal(rng[0], (H, K)) + k = jax.random.normal(rng[1], (H, K)) + v = jax.random.normal(rng[2], (H, V)) + beta = jnp.full((H,), 0.5) + # Plant non-zero prior state, then decay it away. + state = jax.random.normal(rng[3], (H, K, V)) + + # exp(-100) ≈ 0 + g = jnp.full((H,), -100.0) + new_state, _ = _gated_delta_step(state, q, k, v, g, beta) + # Result should be ~ k ⊗ (β·v). + expected = k[..., None] * (beta[..., None] * v)[..., None, :] + np.testing.assert_allclose(new_state, expected, atol=1e-5) + + def test_leading_dim_agnostic_batched_equals_unbatched(self): + """Calling with [B, H, ...] batched should equal calling per-element + with [H, ...] and stacking — the function is leading-dim-agnostic.""" + B, H, K, V = 3, 2, 4, 6 + rng = jax.random.split(jax.random.key(2), 6) + state = jax.random.normal(rng[0], (B, H, K, V)) + q = jax.random.normal(rng[1], (B, H, K)) + k = jax.random.normal(rng[2], (B, H, K)) + v = jax.random.normal(rng[3], (B, H, V)) + g = jax.random.normal(rng[4], (B, H)) * 0.1 + beta = jax.random.normal(rng[5], (B, H)) * 0.5 + + # Batched call. + ns_batch, out_batch = _gated_delta_step(state, q, k, v, g, beta) + + # Per-element calls. + ns_each, out_each = [], [] + for i in range(B): + ns, o = _gated_delta_step(state[i], q[i], k[i], v[i], g[i], beta[i]) + ns_each.append(ns) + out_each.append(o) + ns_stack = jnp.stack(ns_each, axis=0) + out_stack = jnp.stack(out_each, axis=0) + + np.testing.assert_allclose(ns_batch, ns_stack, atol=1e-5) + np.testing.assert_allclose(out_batch, out_stack, atol=1e-5) + + +class CausalConv1dUpdateTest(unittest.TestCase): + """Tests for the decode-path conv1d update. + + State contract matches :func:`jax_causal_conv1d_prefill`: caller passes + the full per-layer ``conv_state`` table plus ``state_indices``; the + kernel gathers per-request state internally. Tests wrap the + ``[B, D, K-1]`` per-request state into a single-slot ``[1, D, K-1]`` + table with ``state_indices=[0]`` … or, with multi-slot scenarios, place + each request's state in a distinct slot to verify the gather. + """ + + def test_matches_window_dot(self): + """y = window · weight (per-channel) where window = [state, x_new].""" + B, D, K = 2, 3, 4 + rng = jax.random.split(jax.random.key(10), 3) + x = jax.random.normal(rng[0], (B, D), dtype=jnp.float32) + state = jax.random.normal(rng[1], (B, D, K - 1), dtype=jnp.float32) + weight = jax.random.normal(rng[2], (D, K), dtype=jnp.float32) + + # Pack the per-request state into a 1-slot-per-request pool. + conv_state = state # [B, D, K-1] doubles as a B-slot table + state_indices = jnp.arange(B, dtype=jnp.int32) + + y, new_state = jax_causal_conv1d_update(x, conv_state, state_indices, weight, bias=None) + + # Reference: window = [state | x_new], y = sum(window * weight) over K. + window = jnp.concatenate([state, x[..., None]], axis=-1) # [B, D, K] + expected_y = (window * weight[None]).sum(axis=-1) + np.testing.assert_allclose(y, expected_y, atol=1e-5) + np.testing.assert_allclose(new_state, window[..., 1:], atol=0) + + def test_silu_activation(self): + """activation='silu' applies SiLU on top of the linear output.""" + B, D, K = 1, 2, 3 + rng = jax.random.split(jax.random.key(11), 3) + x = jax.random.normal(rng[0], (B, D)) + state = jax.random.normal(rng[1], (B, D, K - 1)) + weight = jax.random.normal(rng[2], (D, K)) + + conv_state = state + state_indices = jnp.arange(B, dtype=jnp.int32) + + y_lin, _ = jax_causal_conv1d_update(x, conv_state, state_indices, weight, bias=None) + y_silu, _ = jax_causal_conv1d_update( + x, conv_state, state_indices, weight, bias=None, activation="silu" + ) + np.testing.assert_allclose(y_silu, jax.nn.silu(y_lin), atol=1e-5) + + def test_gather_picks_correct_slot(self): + """Per-request state must come from `state_indices[b]`, not from + positional alignment — verifies the gather is real.""" + D, K = 2, 3 + # Build a 4-slot pool with distinct, recognisable contents per slot. + pool = jnp.stack( + [ + jnp.full((D, K - 1), 0.0), + jnp.full((D, K - 1), 1.0), + jnp.full((D, K - 1), 2.0), + jnp.full((D, K - 1), 3.0), + ] + ) # [4, D, K-1] + weight = jnp.ones((D, K)) # straight sum + x = jnp.zeros((2, D)) # only state should contribute to output + # Two requests pulling from slots 3 and 1, respectively. + state_indices = jnp.array([3, 1], dtype=jnp.int32) + + y, new_state = jax_causal_conv1d_update(x, pool, state_indices, weight, bias=None) + # y[b, d] = sum(state[b] | x_new[b]) = sum(state[b]) since x=0. + # state[req0] = slot 3 = all 3.0 → window sum = 3 * (K-1) = 6.0. + # state[req1] = slot 1 = all 1.0 → window sum = 1 * (K-1) = 2.0. + np.testing.assert_allclose(y[0], 6.0 * jnp.ones(D), atol=1e-5) + np.testing.assert_allclose(y[1], 2.0 * jnp.ones(D), atol=1e-5) + # new_state is now the full pool table with per-request scatter-back. + # Slot 3 (req 0): old state = 3.0; window[..., 1:] keeps the K-2 newer + # taps (still 3.0) and the newest tap is the new token x = 0. + # Slot 1 (req 1): old state = 1.0; same pattern. + # Slots 0, 2 were not touched — unchanged from input pool. + np.testing.assert_allclose(new_state[3, :, :-1], 3.0, atol=1e-5) + np.testing.assert_allclose(new_state[1, :, :-1], 1.0, atol=1e-5) + np.testing.assert_allclose(new_state[3, :, -1], 0.0, atol=1e-5) + np.testing.assert_allclose(new_state[1, :, -1], 0.0, atol=1e-5) + np.testing.assert_allclose(new_state[0], 0.0, atol=1e-5) + np.testing.assert_allclose(new_state[2], 2.0, atol=1e-5) + + +class CausalConv1dPrefillTest(unittest.TestCase): + """The depthwise causal conv1d. Boundary handling is the subtle bit.""" + + def _naive_conv(self, x, weight, init_left=None): + """Reference: per-channel causal conv with optional left-pad state. + + x: [D, T], weight: [D, K], init_left: [D, K-1] or None (zero pad). + Returns y: [D, T]. + """ + D, T = x.shape + K = weight.shape[1] + left = jnp.zeros((D, K - 1), dtype=x.dtype) if init_left is None else init_left + padded = jnp.concatenate([left, x], axis=-1) # [D, T+K-1] + out = [] + for t in range(T): + window = padded[:, t : t + K] # [D, K] + out.append((window * weight).sum(axis=-1)) + return jnp.stack(out, axis=-1) # [D, T] + + def test_single_request_no_state_matches_naive(self): + D, K = 3, 3 + T = 5 + rng = jax.random.split(jax.random.key(20), 2) + x = jax.random.normal(rng[0], (D, T), dtype=jnp.float32) + weight = jax.random.normal(rng[1], (D, K), dtype=jnp.float32) + + y, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=None, + state_indices=None, + ) + np.testing.assert_allclose(y, self._naive_conv(x, weight), atol=1e-5) + # Final state should be the last K-1 tokens of x. + np.testing.assert_allclose(final[0], x[:, -(K - 1) :], atol=1e-5) + + def test_initial_state_carried_in(self): + """First K-1 outputs should mix in conv_state contents.""" + D, K = 2, 3 + T = 4 + rng = jax.random.split(jax.random.key(21), 3) + x = jax.random.normal(rng[0], (D, T)) + weight = jax.random.normal(rng[1], (D, K)) + # Slot 0 = null block (zeros), slot 1 carries non-zero state. + prior = jax.random.normal(rng[2], (D, K - 1)) + conv_state = jnp.zeros((2, D, K - 1)).at[1].set(prior) + + y, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([1], dtype=jnp.int32), + ) + np.testing.assert_allclose(y, self._naive_conv(x, weight, init_left=prior), atol=1e-5) + # Final state lives at the request's slot in the full table (state_indices=[1]). + np.testing.assert_allclose(final[1], x[:, -(K - 1) :], atol=1e-5) + + def test_multi_request_boundary_isolation(self): + """Token 0 of request 1 must NOT see any token from request 0.""" + D, K = 1, 3 + # Two requests of length 3 each, packed into x of length 6. + # Set channel 0 of req0 to all 100 (poison) and req1 to all 1. + x = jnp.zeros((D, 6)) + x = x.at[0, :3].set(100.0) + x = x.at[0, 3:].set(1.0) + weight = jnp.ones((D, K)) # straight sum of window + + y, _ = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, 3, 6], dtype=jnp.int32), + conv_state=None, + state_indices=None, + ) + # Req1's token 0 (global idx 3) should sum [0, 0, 1] = 1, NOT 100+1+1. + # Req1's token 1 should sum [0, 1, 1] = 2. + # Req1's token 2 should sum [1, 1, 1] = 3. + np.testing.assert_allclose(y[0, 3:], [1.0, 2.0, 3.0], atol=1e-5) + # Req0 unaffected: 100, 200, 300. + np.testing.assert_allclose(y[0, :3], [100.0, 200.0, 300.0], atol=1e-5) + + def test_short_request_left_padded_from_state(self): + """A request shorter than K-1 has its final_state assembled from + (state-left-pad) + (whatever real tokens it has).""" + D, K = 1, 4 # K-1 = 3 + T = 2 # request shorter than K-1 + x = jnp.array([[10.0, 20.0]]) + weight = jnp.zeros((D, K)) # output unused, focus on final_state + prior = jnp.array([[7.0, 8.0, 9.0]]) # state[0,1,2] = newest at idx 2 + + conv_state = jnp.zeros((1, D, K - 1)).at[0].set(prior) + _, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([0], dtype=jnp.int32), + ) + # State holds the K-1=3 most recent tokens BEFORE this batch, newest + # at index K-2=2. Before this batch the (logical) tail is + # ..., prior[0]=7, prior[1]=8, prior[2]=9. After appending 10, 20 + # the new most-recent K-1=3 logical tokens are [9, 10, 20]. + np.testing.assert_allclose(final[0, 0], [9.0, 10.0, 20.0], atol=1e-5) + + def test_short_request_left_padded_with_zeros_when_no_state(self): + """Same as above but with no prior state — the K-1-T missing + lookback positions zero-pad rather than pulling from a slot.""" + D, K = 1, 4 # K-1 = 3 left-pad slots needed + T = 2 + x = jnp.array([[10.0, 20.0]]) + weight = jnp.zeros((D, K)) + + _, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=None, + state_indices=None, + ) + # Logical stream is [pad=0, 10, 20]; last K-1=3 = [0, 10, 20]. + np.testing.assert_allclose(final[0, 0], [0.0, 10.0, 20.0], atol=1e-5) + + def test_kernel_size_1_no_lookback(self): + """K=1 means depthwise per-token multiply with no temporal mixing. + ``final_state`` has shape ``(B, D, 0)`` since there's no state to keep.""" + D, K = 2, 1 + T = 4 + rng = jax.random.split(jax.random.key(22), 2) + x = jax.random.normal(rng[0], (D, T)) + weight = jax.random.normal(rng[1], (D, K)) + + y, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=None, + state_indices=None, + ) + # y[d, t] = x[d, t] * weight[d, 0]. + np.testing.assert_allclose(y, x * weight, atol=1e-5) + self.assertEqual(final.shape, (1, D, 0)) + + def test_has_initial_state_false_ignores_slot(self): + """For a brand-new prefill (has_initial_state=False) the gathered + slot must be treated as zeros — output should match a fresh prefill + with no prior state, regardless of what's in the slot. + """ + D, K = 2, 3 + T = 4 + rng = jax.random.split(jax.random.key(40), 3) + x = jax.random.normal(rng[0], (D, T)) + weight = jax.random.normal(rng[1], (D, K)) + # Put non-zero "stale" data in the slot — must NOT leak into output. + stale = jax.random.normal(rng[2], (D, K - 1)) * 5.0 + conv_state = jnp.zeros((2, D, K - 1)).at[1].set(stale) + + y_masked, final_masked = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False], dtype=jnp.bool_), + ) + # Reference: same conv with NO prior state at all. + y_fresh, final_fresh = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=None, + state_indices=None, + ) + np.testing.assert_allclose(y_masked, y_fresh, atol=1e-5) + # `final_masked` is the full pool table (scatter inside kernel); pluck + # the request's slot. `final_fresh` is the per-request fallback because + # `conv_state=None` skipped the scatter. + np.testing.assert_allclose(final_masked[1], final_fresh[0], atol=1e-5) + + def test_has_initial_state_true_uses_slot(self): + """has_initial_state=True should be equivalent to the legacy + behavior (no mask) — verifies the mask is the *only* difference + introduced by the new arg. + """ + D, K = 2, 3 + T = 4 + rng = jax.random.split(jax.random.key(41), 3) + x = jax.random.normal(rng[0], (D, T)) + weight = jax.random.normal(rng[1], (D, K)) + prior = jax.random.normal(rng[2], (D, K - 1)) + conv_state = jnp.zeros((2, D, K - 1)).at[1].set(prior) + + y_with_mask, final_with_mask = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([True], dtype=jnp.bool_), + ) + y_no_arg, final_no_arg = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([1], dtype=jnp.int32), + ) + np.testing.assert_allclose(y_with_mask, y_no_arg, atol=1e-5) + np.testing.assert_allclose(final_with_mask, final_no_arg, atol=1e-5) + + def test_has_initial_state_mixed_per_request(self): + """Two requests packed together: req0 has prior state and uses it; + req1 is a brand-new prefill and must NOT see its slot's contents.""" + D, K = 1, 3 + # Two requests of length 4 each. + x = jnp.concatenate( + [jnp.array([[1.0, 2.0, 3.0, 4.0]]), jnp.array([[5.0, 6.0, 7.0, 8.0]])], axis=-1 + ) # [D=1, T=8] + weight = jnp.ones((D, K)) # straight sum of K-window + # Slot 1 holds prior=[10, 20] for req0; slot 2 has poison [99, 99] for req1. + prior_req0 = jnp.array([[10.0, 20.0]]) + poison_req1 = jnp.array([[99.0, 99.0]]) + conv_state = jnp.zeros((3, D, K - 1)).at[1].set(prior_req0).at[2].set(poison_req1) + + y, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, 4, 8], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([1, 2], dtype=jnp.int32), + has_initial_state=jnp.array([True, False], dtype=jnp.bool_), + ) + # Req0 token 0 sums [prior_req0[1]=20, prior_req0[2]... wait, K-1=2 only. + # Logical pre-batch stream for req0: ..., 10, 20 (newest at idx 1). + # Window for req0[0] = [10, 20, x[0]=1] sum = 31. + # Window for req0[1] = [20, 1, 2] sum = 23. + # Window for req0[2] = [1, 2, 3] sum = 6. + # Window for req0[3] = [2, 3, 4] sum = 9. + np.testing.assert_allclose(y[0, :4], [31.0, 23.0, 6.0, 9.0], atol=1e-5) + # Req1 token 0: poison must be masked → treated as [0, 0, x[4]=5] sum = 5. + # NOT [99, 99, 5] = 203. + # Req1 token 1 = [0, 5, 6] sum = 11. + # Req1 token 2 = [5, 6, 7] sum = 18. + # Req1 token 3 = [6, 7, 8] sum = 21. + np.testing.assert_allclose(y[0, 4:], [5.0, 11.0, 18.0, 21.0], atol=1e-5) + # Final states scattered back into slots 1 and 2 (state_indices=[1, 2]). + np.testing.assert_allclose(final[1, 0], [3.0, 4.0], atol=1e-5) + np.testing.assert_allclose(final[2, 0], [7.0, 8.0], atol=1e-5) + # Slot 0 was unused — unchanged from input (zeros). + np.testing.assert_allclose(final[0], 0.0, atol=1e-5) + + def test_has_initial_state_false_short_request_zero_padded(self): + """Regression test for the bug: a request shorter than K-1 with + has_initial_state=False must produce a final_state that's + left-padded with ZEROS, not with stale slot contents.""" + D, K = 1, 4 # K-1 = 3 left-pad slots needed + T = 2 # request shorter than K-1 + x = jnp.array([[10.0, 20.0]]) + weight = jnp.zeros((D, K)) # output unused, focus on final_state + # Stale slot contents — must NOT leak into final_state. + stale = jnp.array([[77.0, 88.0, 99.0]]) + conv_state = jnp.zeros((1, D, K - 1)).at[0].set(stale) + + _, final = jax_causal_conv1d_prefill( + x, + weight, + bias=None, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + conv_state=conv_state, + state_indices=jnp.array([0], dtype=jnp.int32), + has_initial_state=jnp.array([False], dtype=jnp.bool_), + ) + # Logical stream for a fresh prefill of 2 tokens is [pad=0, 10, 20]; + # last K-1=3 = [0, 10, 20]. The stale 77/88/99 must be invisible. + np.testing.assert_allclose(final[0, 0], [0.0, 10.0, 20.0], atol=1e-5) + + +class DecodeGatedDeltaRuleRefTest(unittest.TestCase): + """The decode kernel is the parallel-across-B specialisation of the + ragged kernel. It must be numerically equivalent to running the + ragged kernel with cu_seqlens = arange(B+1).""" + + def test_matches_ragged_with_singleton_seqs(self): + """decode_gated_delta_rule_ref(...) == + ragged_gated_delta_rule_ref(cu_seqlens=arange(B+1), has_initial_state=True).""" + n_kq, n_v, d_k, d_v = 1, 2, 4, 8 + B = 5 + conv_dim = 2 * n_kq * d_k + n_v * d_v + rng = jax.random.split(jax.random.key(30), 6) + mq = jax.random.normal(rng[0], (B, conv_dim), dtype=jnp.bfloat16) * 0.3 + b = jax.random.normal(rng[1], (B, n_v), dtype=jnp.bfloat16) * 0.5 + a = jax.random.normal(rng[2], (B, n_v), dtype=jnp.bfloat16) * 0.5 + A_log = jax.random.normal(rng[3], (n_v,)) * 0.3 + dt_bias = jax.random.normal(rng[4], (n_v,)) * 0.3 + rec = jax.random.normal(rng[5], (B + 1, n_v, d_k, d_v), dtype=jnp.float32) * 0.05 + si = jnp.array([1, 2, 3, 4, 5], dtype=jnp.int32) + + nr_d, out_d = decode_gated_delta_rule_ref( + mq, + b, + a, + rec, + A_log, + dt_bias, + si, + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + nr_r, out_r = ragged_gated_delta_rule_ref( + mq, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.arange(B + 1, dtype=jnp.int32), + state_indices=si, + has_initial_state=jnp.ones((B,), dtype=jnp.bool_), + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + np.testing.assert_allclose(out_d, out_r, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(nr_d, nr_r, atol=1e-4, rtol=1e-4) + + def test_output_shapes(self): + """Per-request outputs at expected shapes/dtypes.""" + n_kq, n_v, d_k, d_v = 2, 4, 8, 16 + B = 3 + conv_dim = 2 * n_kq * d_k + n_v * d_v + mq = jnp.ones((B, conv_dim), dtype=jnp.bfloat16) * 0.1 + b = jnp.zeros((B, n_v), dtype=jnp.bfloat16) + a = jnp.zeros((B, n_v), dtype=jnp.bfloat16) + A_log = jnp.zeros((n_v,), dtype=jnp.float32) + dt_bias = jnp.zeros((n_v,), dtype=jnp.float32) + rec = jnp.zeros((B + 1, n_v, d_k, d_v), dtype=jnp.float32) + si = jnp.array([1, 2, 3], dtype=jnp.int32) + + new_rec, out = decode_gated_delta_rule_ref( + mq, + b, + a, + rec, + A_log, + dt_bias, + si, + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + self.assertEqual(out.shape, (B, n_v, d_v)) + self.assertEqual(out.dtype, jnp.bfloat16) + # `new_rec` is the full pool table (kernel scatters internally). + self.assertEqual(new_rec.shape, (B + 1, n_v, d_k, d_v)) + self.assertEqual(new_rec.dtype, jnp.float32) + + def test_gqa_matches_ragged_with_singletons(self): + """GQA (n_v > n_kq, v_per_k > 1): decode kernel still equals the + ragged kernel with cu_seqlens=arange(B+1). Q/K head expansion must + happen inside both impls for them to agree.""" + n_kq, n_v, d_k, d_v = 2, 4, 8, 8 + B = 3 + conv_dim = 2 * n_kq * d_k + n_v * d_v + rng = jax.random.split(jax.random.key(31), 6) + mq = jax.random.normal(rng[0], (B, conv_dim), dtype=jnp.bfloat16) * 0.3 + b = jax.random.normal(rng[1], (B, n_v), dtype=jnp.bfloat16) * 0.5 + a = jax.random.normal(rng[2], (B, n_v), dtype=jnp.bfloat16) * 0.5 + A_log = jax.random.normal(rng[3], (n_v,)) * 0.3 + dt_bias = jax.random.normal(rng[4], (n_v,)) * 0.3 + rec = jax.random.normal(rng[5], (B + 1, n_v, d_k, d_v), dtype=jnp.float32) * 0.05 + si = jnp.array([1, 2, 3], dtype=jnp.int32) + + nr_d, out_d = decode_gated_delta_rule_ref( + mq, + b, + a, + rec, + A_log, + dt_bias, + si, + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + nr_r, out_r = ragged_gated_delta_rule_ref( + mq, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.arange(B + 1, dtype=jnp.int32), + state_indices=si, + has_initial_state=jnp.ones((B,), dtype=jnp.bool_), + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + np.testing.assert_allclose(out_d, out_r, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(nr_d, nr_r, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sgl_jax/test/kernels/gdn/test_ragged_gated_delta_rule_ref.py b/python/sgl_jax/test/kernels/gdn/test_ragged_gated_delta_rule_ref.py new file mode 100644 index 0000000000..aa8504d5ea --- /dev/null +++ b/python/sgl_jax/test/kernels/gdn/test_ragged_gated_delta_rule_ref.py @@ -0,0 +1,485 @@ +"""Unit tests for ``ragged_gated_delta_rule_ref``. + +These tests exist primarily as a debugging aid — each case isolates one +piece of the kernel's behaviour (gating math, ragged batching, initial-state +masking, padding, GQA, decode equivalence). When a test fails, the +``_python_reference`` function runs the same recurrence in straight Python +so you can step through token-by-token and compare against the kernel. + +Run with: + JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \\ + python -m pytest test/srt/test_ragged_gated_delta_rule_ref.py -v +""" + +from __future__ import annotations + +import os +import unittest + +# Force CPU + 8 fake devices before importing JAX so explicit-axis meshes work. +os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8") + +import jax +import jax.numpy as jnp +import numpy as np + +from sgl_jax.srt.kernels.gdn import ragged_gated_delta_rule_ref + +# --------------------------------------------------------------------------- +# Helpers: a straight-Python reference for one request and a small fixture. +# --------------------------------------------------------------------------- + + +def _l2norm(x, eps=1e-6): + x = x.astype(jnp.float32) + return x / jnp.sqrt((x * x).sum(axis=-1, keepdims=True) + eps) + + +def _python_reference( + mixed_qkv, # [T, 2*n_kq*d_k + n_v*d_v] + b, + a, # [T, n_v], [T, n_v] + initial_state, # [n_v, d_k, d_v] fp32 + A_log, + dt_bias, # [n_v] + n_kq, + n_v, + d_k, + d_v, +): + """Token-by-token reference for a SINGLE request. + + Mirrors the kernel's math line for line, but with a Python ``for`` loop + instead of ``lax.scan`` and no ragged-batch indexing — useful as an + independent oracle when a test fails. + """ + T = mixed_qkv.shape[0] + key_dim = n_kq * d_k + q = mixed_qkv[:, :key_dim] + k = mixed_qkv[:, key_dim : 2 * key_dim] + v = mixed_qkv[:, 2 * key_dim :] + + repeat = n_v // n_kq + A = jnp.exp(A_log.astype(jnp.float32)) + scale = d_k**-0.5 + + state = initial_state.astype(jnp.float32) + outs = [] + for t in range(T): + q_h = q[t].reshape(n_kq, d_k) + k_h = k[t].reshape(n_kq, d_k) + v_h = v[t].reshape(n_v, d_v) + if repeat > 1: + q_h = jnp.repeat(q_h, repeat, axis=0) + k_h = jnp.repeat(k_h, repeat, axis=0) + q_h = _l2norm(q_h) * scale + k_h = _l2norm(k_h) + v_h = v_h.astype(jnp.float32) + beta = jax.nn.sigmoid(b[t].astype(jnp.float32)) + g = -A * jax.nn.softplus(a[t].astype(jnp.float32) + dt_bias.astype(jnp.float32)) + + decay = jnp.exp(g)[:, None, None] + state = state * decay + kv_mem = (state * k_h[..., None]).sum(axis=-2) + delta = (v_h - kv_mem) * beta[:, None] + state = state + k_h[..., None] * delta[..., None, :] + out = (state * q_h[..., None]).sum(axis=-2) # [n_v, d_v] + outs.append(out) + output = jnp.stack(outs, axis=0) # [T, n_v, d_v] + return state, output.astype(mixed_qkv.dtype) + + +def _make_inputs(seed, total_tokens, n_kq, n_v, d_k, d_v, dtype=jnp.bfloat16): + """Random fixture: returns (mixed_qkv, b, a, A_log, dt_bias).""" + keys = jax.random.split(jax.random.key(seed), 5) + conv_dim = 2 * n_kq * d_k + n_v * d_v + mixed_qkv = jax.random.normal(keys[0], (total_tokens, conv_dim), dtype=dtype) * 0.5 + b = jax.random.normal(keys[1], (total_tokens, n_v), dtype=dtype) * 0.5 + a = jax.random.normal(keys[2], (total_tokens, n_v), dtype=dtype) * 0.5 + A_log = jax.random.normal(keys[3], (n_v,), dtype=jnp.float32) * 0.3 + dt_bias = jax.random.normal(keys[4], (n_v,), dtype=jnp.float32) * 0.3 + return mixed_qkv, b, a, A_log, dt_bias + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class RaggedGatedDeltaRuleRefTest(unittest.TestCase): + # All tests share these shapes — small enough to inspect by hand. + n_kq, n_v, d_k, d_v = 1, 2, 4, 8 + NUM_BLOCKS = 4 # state table size; slot 0 is the null block. + + # --- Case 1: single request, fresh prefill ----------------------------- + def test_single_request_matches_python_reference(self): + """B=1, has_initial_state=False — the simplest case. + + The kernel and the Python reference should produce bit-identical + outputs (same math, same dtypes) modulo bf16 rounding from the + cast on the way out. + """ + T = 6 + mixed_qkv, b, a, A_log, dt_bias = _make_inputs( + 0, T, self.n_kq, self.n_v, self.d_k, self.d_v + ) + rec = jnp.zeros((self.NUM_BLOCKS, self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + + new_rec, out = ragged_gated_delta_rule_ref( + mixed_qkv, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False], dtype=jnp.bool_), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + + ref_state, ref_out = _python_reference( + mixed_qkv, + b, + a, + jnp.zeros((self.n_v, self.d_k, self.d_v), dtype=jnp.float32), + A_log, + dt_bias, + self.n_kq, + self.n_v, + self.d_k, + self.d_v, + ) + + np.testing.assert_allclose(out, ref_out, atol=1e-3, rtol=1e-3) + # `new_rec` is the full pool table (kernel scatters internally). + # state_indices=[1] → the updated slot is index 1. + np.testing.assert_allclose(new_rec[1], ref_state, atol=1e-4, rtol=1e-4) + + # --- Case 2: with initial state ---------------------------------------- + def test_initial_state_is_picked_up(self): + """has_initial_state=True — the kernel should gather slot's prior state. + + Plant a non-zero state in slot 2, run with has_initial_state=True + vs False, expect different outputs (False zeroes it). + """ + T = 4 + mixed_qkv, b, a, A_log, dt_bias = _make_inputs( + 1, T, self.n_kq, self.n_v, self.d_k, self.d_v + ) + rec = jnp.zeros((self.NUM_BLOCKS, self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + # Plant a deterministic prior state in slot 2. + prior = jax.random.normal( + jax.random.key(99), (self.n_v, self.d_k, self.d_v), dtype=jnp.float32 + ) + rec = rec.at[2].set(prior) + + common = dict( + mixed_qkv=mixed_qkv, + b=b, + a=a, + A_log=A_log, + dt_bias=dt_bias, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + state_indices=jnp.array([2], dtype=jnp.int32), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + # With has_initial_state=True: prior should flow in. + _, out_with = ragged_gated_delta_rule_ref( + recurrent_state=rec, + has_initial_state=jnp.array([True]), + **common, + ) + ref_with = _python_reference( + mixed_qkv, + b, + a, + prior, + A_log, + dt_bias, + self.n_kq, + self.n_v, + self.d_k, + self.d_v, + )[1] + np.testing.assert_allclose(out_with, ref_with, atol=1e-3, rtol=1e-3) + + # With has_initial_state=False: prior should be ignored. + _, out_without = ragged_gated_delta_rule_ref( + recurrent_state=rec, + has_initial_state=jnp.array([False]), + **common, + ) + ref_without = _python_reference( + mixed_qkv, + b, + a, + jnp.zeros_like(prior), + A_log, + dt_bias, + self.n_kq, + self.n_v, + self.d_k, + self.d_v, + )[1] + np.testing.assert_allclose(out_without, ref_without, atol=1e-3, rtol=1e-3) + + # Sanity: the two should not coincide for non-zero prior. + self.assertFalse(jnp.allclose(out_with, out_without, atol=1e-3)) + + # --- Case 3: ragged batching independence ------------------------------ + def test_ragged_batching_matches_per_request_runs(self): + """Two reqs of different lengths in one packed batch produce the + same outputs as running each request independently. + + If this test passes, sequence boundaries are clean: no token from + req-0 ever leaks into req-1's recurrent state and vice versa. + """ + lens = [3, 5] + T = sum(lens) + mixed_qkv, b, a, A_log, dt_bias = _make_inputs( + 2, T, self.n_kq, self.n_v, self.d_k, self.d_v + ) + rec = jnp.zeros((self.NUM_BLOCKS, self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + cu = jnp.array([0, lens[0], lens[0] + lens[1]], dtype=jnp.int32) + + new_rec, out = ragged_gated_delta_rule_ref( + mixed_qkv, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=cu, + state_indices=jnp.array([1, 2], dtype=jnp.int32), + has_initial_state=jnp.array([False, False]), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + + # Run each request through the Python reference and stack. + zero_state = jnp.zeros((self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + ref_state_0, ref_out_0 = _python_reference( + mixed_qkv[: lens[0]], + b[: lens[0]], + a[: lens[0]], + zero_state, + A_log, + dt_bias, + self.n_kq, + self.n_v, + self.d_k, + self.d_v, + ) + ref_state_1, ref_out_1 = _python_reference( + mixed_qkv[lens[0] :], + b[lens[0] :], + a[lens[0] :], + zero_state, + A_log, + dt_bias, + self.n_kq, + self.n_v, + self.d_k, + self.d_v, + ) + ref_out = jnp.concatenate([ref_out_0, ref_out_1], axis=0) + np.testing.assert_allclose(out, ref_out, atol=1e-3, rtol=1e-3) + # `new_rec` is the full pool table; pluck the per-request slots + # (state_indices=[1, 2]). + np.testing.assert_allclose(new_rec[1], ref_state_0, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(new_rec[2], ref_state_1, atol=1e-4, rtol=1e-4) + + # --- Case 4: padding tokens are ignored -------------------------------- + def test_padding_tokens_do_not_mutate_state(self): + """Tokens beyond cu_seqlens[-1] are padding; their writes are masked + off by ``valid_mask``, so the per-seq new state should equal the + no-padding result. + """ + real_T = 4 + pad = 3 + T = real_T + pad + mixed_qkv, b, a, A_log, dt_bias = _make_inputs( + 3, T, self.n_kq, self.n_v, self.d_k, self.d_v + ) + rec = jnp.zeros((self.NUM_BLOCKS, self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + + new_rec_padded, _ = ragged_gated_delta_rule_ref( + mixed_qkv, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, real_T], dtype=jnp.int32), # last_valid_loc = 4 + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + new_rec_clean, _ = ragged_gated_delta_rule_ref( + mixed_qkv[:real_T], + b[:real_T], + a[:real_T], + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, real_T], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + np.testing.assert_allclose(new_rec_padded[0], new_rec_clean[0], atol=1e-5, rtol=1e-5) + + # --- Case 5: GQA expansion (n_v > n_kq) -------------------------------- + def test_gqa_expansion_matches_pre_repeated(self): + """When n_v > n_kq, the kernel repeats Q/K to n_v heads internally. + + Run with (n_kq=2, n_v=4) and compare against running with q/k + manually pre-repeated and (n_kq=4, n_v=4). Outputs should match. + """ + n_kq, n_v, d_k, d_v = 2, 4, 4, 6 + repeat = n_v // n_kq + T = 5 + key_dim = n_kq * d_k + keys = jax.random.split(jax.random.key(4), 5) + mixed_qkv_in = ( + jax.random.normal(keys[0], (T, 2 * key_dim + n_v * d_v), dtype=jnp.bfloat16) * 0.5 + ) + b = jax.random.normal(keys[1], (T, n_v), dtype=jnp.bfloat16) * 0.5 + a = jax.random.normal(keys[2], (T, n_v), dtype=jnp.bfloat16) * 0.5 + A_log = jax.random.normal(keys[3], (n_v,), dtype=jnp.float32) * 0.3 + dt_bias = jax.random.normal(keys[4], (n_v,), dtype=jnp.float32) * 0.3 + rec = jnp.zeros((2, n_v, d_k, d_v), dtype=jnp.float32) + + new_rec_a, out_a = ragged_gated_delta_rule_ref( + mixed_qkv_in, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=n_kq, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + + # Now build the equivalent (n_kq=n_v) input: repeat q and k across + # heads, leave v unchanged. + q = mixed_qkv_in[:, :key_dim].reshape(T, n_kq, d_k) + k = mixed_qkv_in[:, key_dim : 2 * key_dim].reshape(T, n_kq, d_k) + v = mixed_qkv_in[:, 2 * key_dim :] + q_rep = jnp.repeat(q, repeat, axis=1).reshape(T, n_v * d_k) + k_rep = jnp.repeat(k, repeat, axis=1).reshape(T, n_v * d_k) + mixed_qkv_eq = jnp.concatenate([q_rep, k_rep, v], axis=-1) + new_rec_b, out_b = ragged_gated_delta_rule_ref( + mixed_qkv_eq, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=n_v, + n_v=n_v, + d_k=d_k, + d_v=d_v, + ) + np.testing.assert_allclose(out_a, out_b, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(new_rec_a, new_rec_b, atol=1e-4, rtol=1e-4) + + # --- Case 6: prefill-then-decode equivalence --------------------------- + def test_prefill_then_decode_equals_full_prefill(self): + """Run T tokens as one prefill, vs the same T tokens broken into a + prefill (T-1 tokens) followed by a single decode step. The final + state and last output should match. + + This is the contract decode mode relies on: continuing a sequence + token-by-token from a saved state must produce the same trajectory + as scanning the whole thing in one go. + """ + T = 6 + mixed_qkv, b, a, A_log, dt_bias = _make_inputs( + 5, T, self.n_kq, self.n_v, self.d_k, self.d_v + ) + rec = jnp.zeros((self.NUM_BLOCKS, self.n_v, self.d_k, self.d_v), dtype=jnp.float32) + + # Path A: full prefill. + new_rec_full, out_full = ragged_gated_delta_rule_ref( + mixed_qkv, + b, + a, + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, T], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + + # Path B: prefill T-1 tokens, then one decode step. + new_rec_pref, _ = ragged_gated_delta_rule_ref( + mixed_qkv[: T - 1], + b[: T - 1], + a[: T - 1], + rec, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, T - 1], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([False]), + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + # `new_rec_pref` is already the full pool table with slot 1 updated — + # pass it straight in as the state table for the decode step. + new_rec_dec, out_dec = ragged_gated_delta_rule_ref( + mixed_qkv[T - 1 :], + b[T - 1 :], + a[T - 1 :], + new_rec_pref, + A_log, + dt_bias, + cu_seqlens=jnp.array([0, 1], dtype=jnp.int32), + state_indices=jnp.array([1], dtype=jnp.int32), + has_initial_state=jnp.array([True]), # continuation + n_kq=self.n_kq, + n_v=self.n_v, + d_k=self.d_k, + d_v=self.d_v, + ) + + # Last-token output and final state should coincide; per-request slot + # is index 1 in the full pool table. + np.testing.assert_allclose(out_full[-1], out_dec[0], atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(new_rec_full[1], new_rec_dec[1], atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sgl_jax/test/layers/test_gdn_backend.py b/python/sgl_jax/test/layers/test_gdn_backend.py new file mode 100644 index 0000000000..4b7dcd8111 --- /dev/null +++ b/python/sgl_jax/test/layers/test_gdn_backend.py @@ -0,0 +1,290 @@ +"""Unit tests for :class:`GDNAttnBackend`. + +The kernels themselves are covered in ``test_gated_delta.py`` and +``test_ragged_gated_delta_rule_ref.py``; this file exercises the backend +glue: ``__init__`` parameter ownership, decode/extend dispatch, and the +``shard_map``-wrapped conv + recurrence pipeline. + +The backend inherits ``LinearRecurrentAttnBackend`` and reads +``cu_q_lens`` / ``recurrent_indices`` / ``has_initial_state`` from +``self.forward_metadata`` (normally populated by +``get_forward_metadata(batch)`` before the forward; we set it directly +here for unit testing). State is fetched from a +``recurrent_state_pool``-shaped object via ``get_layer_cache``. + +Run with: + JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \\ + python -m pytest test/srt/test_gdn_backend.py -v +""" + +from __future__ import annotations + +import os +import unittest + +os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8") + +import jax +import jax.numpy as jnp +from flax import nnx +from jax.experimental import mesh_utils +from jax.sharding import AxisType, Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from sgl_jax.srt.layers.attention.hybrid_linear_attn_backend import ( + LinearRecurrentAttnBackendMetadata, +) +from sgl_jax.srt.layers.attention.linear.gdn_backend import GDNAttnBackend + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_mesh(): + """Single-device 1×1 mesh with both 'data' and 'tensor' axes.""" + devices = mesh_utils.create_device_mesh((8,))[:1].reshape((1, 1)) + return Mesh( + devices, + ("data", "tensor"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def _make_backend(mesh, n_kq=1, n_v=2, d_k=4, d_v=8, K=3): + backend = GDNAttnBackend( + num_k_heads=n_kq, + num_v_heads=n_v, + head_k_dim=d_k, + head_v_dim=d_v, + conv_kernel_size=K, + mesh=mesh, + dtype=jnp.bfloat16, + ) + rng = jax.random.split(jax.random.key(0), 3) + backend.conv1d_weight = nnx.Param( + jax.random.normal(rng[0], (backend.conv_dim, K), dtype=jnp.bfloat16) * 0.1 + ) + backend.A_log = nnx.Param(jax.random.normal(rng[1], (n_v,)) * 0.3) + backend.dt_bias = nnx.Param(jax.random.normal(rng[2], (n_v,)) * 0.3) + return backend + + +class _FakeForwardMode: + def __init__(self, decode: bool): + self._decode = decode + + def is_decode(self): + return self._decode + + +class _FakeForwardBatch: + """Minimal forward batch — the new backend only reads forward_mode here. + + All ragged-batch info (cu_seqlens, recurrent_indices, has_initial_state) + is read from ``backend.forward_metadata`` instead, which the tests set + directly via ``_set_metadata``. + """ + + def __init__(self, is_decode: bool): + self.forward_mode = _FakeForwardMode(is_decode) + + +class _FakePool: + """Stand-in for :class:`RecurrentStatePool` that exposes the single + method ``get_linear_recurrent_layer_cache`` the backend uses. + + Holds one ``recurrent_buffer`` and a one-element conv-buffer list per + layer (GDN has a single fused conv per layer; KDA would have three). + """ + + def __init__(self, recurrent_buffer, conv_buffer): + self._rec = recurrent_buffer + self._conv_list = [conv_buffer] + + def get_linear_recurrent_layer_cache(self, layer_id: int): + return self._rec, self._conv_list + + +def _set_metadata(backend, cu_q_lens=None, recurrent_indices=None, has_initial_state=None): + backend.forward_metadata = LinearRecurrentAttnBackendMetadata( + cu_q_lens=cu_q_lens, + recurrent_indices=recurrent_indices, + has_initial_state=has_initial_state, + ) + + +def _sharded_state(mesh, shape, spec, dtype, rng=None): + """Allocate an array with explicit sharding (matches what + ``RecurrentStatePool`` produces in production).""" + out_sharding = NamedSharding(mesh, spec) + if rng is None: + return jnp.zeros(shape, dtype=dtype, out_sharding=out_sharding) + return jax.random.normal(rng, shape, dtype=dtype, out_sharding=out_sharding) + + +def _sharded_proj(mesh, shape, spec, rng, scale=0.3): + """Mimic a `LinearBase` output: sharded on the last (head/channel) axis.""" + return ( + jax.random.normal( + rng, + shape, + dtype=jnp.bfloat16, + out_sharding=NamedSharding(mesh, spec), + ) + * scale + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class GDNAttnBackendInitTest(unittest.TestCase): + def test_param_shapes_match_config(self): + """``conv_dim`` is derived from head counts/dims; param shapes track it.""" + mesh = _make_mesh() + with jax.set_mesh(mesh): + n_kq, n_v, d_k, d_v, K = 2, 4, 8, 16, 4 + backend = GDNAttnBackend( + num_k_heads=n_kq, + num_v_heads=n_v, + head_k_dim=d_k, + head_v_dim=d_v, + conv_kernel_size=K, + mesh=mesh, + dtype=jnp.bfloat16, + ) + self.assertEqual(backend.key_dim, n_kq * d_k) + self.assertEqual(backend.value_dim, n_v * d_v) + self.assertEqual(backend.conv_dim, 2 * backend.key_dim + backend.value_dim) + self.assertEqual(backend.conv1d_weight.value.shape, (backend.conv_dim, K)) + self.assertEqual(backend.A_log.value.shape, (n_v,)) + self.assertEqual(backend.dt_bias.value.shape, (n_v,)) + + +class GDNAttnBackendDispatchTest(unittest.TestCase): + """``__call__`` routes to forward_decode/forward_extend on forward_mode.""" + + def _make_inputs(self, mesh, T, backend, rng_root): + rng = jax.random.split(rng_root, 3) + mixed_qkv = _sharded_proj(mesh, (T, backend.conv_dim), P(None, "tensor"), rng[0]) + b = _sharded_proj(mesh, (T, backend.num_v_heads), P(None, "tensor"), rng[1]) + a = _sharded_proj(mesh, (T, backend.num_v_heads), P(None, "tensor"), rng[2]) + return mixed_qkv, b, a + + def _make_pool(self, mesh, backend, B): + cs = _sharded_state( + mesh, + (B + 1, backend.conv_dim, backend.conv_kernel_size - 1), + P(None, "tensor", None), + jnp.bfloat16, + ) + rs = _sharded_state( + mesh, + (B + 1, backend.num_v_heads, backend.head_k_dim, backend.head_v_dim), + P(None, "tensor", None, None), + jnp.float32, + ) + return _FakePool(recurrent_buffer=rs, conv_buffer=cs) + + def test_decode_dispatch(self): + mesh = _make_mesh() + with jax.set_mesh(mesh): + backend = _make_backend(mesh) + B = 2 + mixed_qkv, b, a = self._make_inputs(mesh, B, backend, jax.random.key(1)) + pool = self._make_pool(mesh, backend, B) + _set_metadata( + backend, + recurrent_indices=jnp.array([1, 2], dtype=jnp.int32), + ) + fb = _FakeForwardBatch(is_decode=True) + out, new_conv, new_rec = backend(fb, mixed_qkv, b, a, pool, layer_id=0) + self.assertEqual(out.shape, (B, backend.num_v_heads, backend.head_v_dim)) + # `new_conv` / `new_rec` are the full pool tables (kernel scatters + # per-request slots in place). Pool was sized `B + 1`. + self.assertEqual( + new_conv.shape, (B + 1, backend.conv_dim, backend.conv_kernel_size - 1) + ) + self.assertEqual( + new_rec.shape, (B + 1, backend.num_v_heads, backend.head_k_dim, backend.head_v_dim) + ) + + def test_extend_dispatch(self): + mesh = _make_mesh() + with jax.set_mesh(mesh): + backend = _make_backend(mesh) + T = 5 # 2 reqs of lengths [3, 2] + B = 2 + mixed_qkv, b, a = self._make_inputs(mesh, T, backend, jax.random.key(2)) + pool = self._make_pool(mesh, backend, B) + _set_metadata( + backend, + cu_q_lens=jnp.array([0, 3, 5], dtype=jnp.int32), + recurrent_indices=jnp.array([1, 2], dtype=jnp.int32), + has_initial_state=jnp.array([False, False], dtype=jnp.bool_), + ) + fb = _FakeForwardBatch(is_decode=False) + out, new_conv, new_rec = backend(fb, mixed_qkv, b, a, pool, layer_id=0) + self.assertEqual(out.shape, (T, backend.num_v_heads, backend.head_v_dim)) + self.assertEqual( + new_conv.shape, (B + 1, backend.conv_dim, backend.conv_kernel_size - 1) + ) + self.assertEqual( + new_rec.shape, (B + 1, backend.num_v_heads, backend.head_k_dim, backend.head_v_dim) + ) + + +class GDNAttnBackendExtendStateTest(unittest.TestCase): + """forward_extend should return per-request new states with finite values.""" + + def test_extend_returns_per_request_state_shape(self): + mesh = _make_mesh() + with jax.set_mesh(mesh): + backend = _make_backend(mesh) + lens = [4, 2, 1] + T = sum(lens) + B = len(lens) + rng = jax.random.split(jax.random.key(50), 3) + mixed_qkv = _sharded_proj(mesh, (T, backend.conv_dim), P(None, "tensor"), rng[0]) + b = _sharded_proj(mesh, (T, backend.num_v_heads), P(None, "tensor"), rng[1]) + a = _sharded_proj(mesh, (T, backend.num_v_heads), P(None, "tensor"), rng[2]) + cs = _sharded_state( + mesh, + (B + 1, backend.conv_dim, backend.conv_kernel_size - 1), + P(None, "tensor", None), + jnp.bfloat16, + rng=jax.random.key(60), + ) + rs = _sharded_state( + mesh, + (B + 1, backend.num_v_heads, backend.head_k_dim, backend.head_v_dim), + P(None, "tensor", None, None), + jnp.float32, + rng=jax.random.key(61), + ) + pool = _FakePool(recurrent_buffer=rs, conv_buffer=cs) + _set_metadata( + backend, + cu_q_lens=jnp.array([0, 4, 6, 7], dtype=jnp.int32), + recurrent_indices=jnp.array([1, 2, 3], dtype=jnp.int32), + has_initial_state=jnp.array([False, False, False], dtype=jnp.bool_), + ) + fb = _FakeForwardBatch(is_decode=False) + out, new_conv, new_rec = backend(fb, mixed_qkv, b, a, pool, layer_id=0) + self.assertEqual(out.shape, (T, backend.num_v_heads, backend.head_v_dim)) + self.assertEqual( + new_conv.shape, (B + 1, backend.conv_dim, backend.conv_kernel_size - 1) + ) + self.assertEqual( + new_rec.shape, (B + 1, backend.num_v_heads, backend.head_k_dim, backend.head_v_dim) + ) + self.assertTrue(bool(jnp.all(jnp.isfinite(out)))) + self.assertTrue(bool(jnp.all(jnp.isfinite(new_rec)))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sgl_jax/test/layers/test_merged_column_parallel_linear.py b/python/sgl_jax/test/layers/test_merged_column_parallel_linear.py new file mode 100644 index 0000000000..9293eb7743 --- /dev/null +++ b/python/sgl_jax/test/layers/test_merged_column_parallel_linear.py @@ -0,0 +1,90 @@ +"""Unit tests for :class:`MergedColumnParallelLinear`. + +The primitive is a generalised drop-in for a stack of independent +column-parallel ``LinearBase``s fused into one bigger GEMM (better MXU +utilization on TPU). Weight loading is the caller's responsibility (no +built-in loader yet — see class docstring). The forward identity is +just ``LinearBase``'s matmul, so tests focus on: + +* the merged weight has shape ``[input_size, sum(output_sizes)]``; +* default no-bias behaviour matches ``LinearBase``; +* construction rejects component sizes that don't divide TP — the + divisibility guard the per-rank block-concat layout depends on. + +Run with: + JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \\ + python -m pytest test/srt/test_merged_column_parallel_linear.py -v +""" + +from __future__ import annotations + +import os +import unittest + +os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8") + +import jax +from jax.experimental import mesh_utils +from jax.sharding import AxisType, Mesh + +from sgl_jax.srt.layers.linear import MergedColumnParallelLinear + + +def _mesh_1x1(): + devices = mesh_utils.create_device_mesh((8,))[:1].reshape((1, 1)) + return Mesh( + devices, + ("data", "tensor"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def _mesh_1xN(n: int): + devices = mesh_utils.create_device_mesh((8,))[:n].reshape((1, n)) + return Mesh( + devices, + ("data", "tensor"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +class MergedColumnParallelInitTest(unittest.TestCase): + def test_weight_shape_is_sum_of_output_sizes(self): + """Single merged weight of width ``sum(output_sizes)``.""" + mesh = _mesh_1x1() + with jax.set_mesh(mesh): + layer = MergedColumnParallelLinear( + input_size=32, + output_sizes=[64, 64, 128], + mesh=mesh, + ) + self.assertEqual(layer.weight.value.shape, (32, 64 + 64 + 128)) + self.assertEqual(layer.output_sizes, [64, 64, 128]) + + def test_no_bias_by_default(self): + mesh = _mesh_1x1() + with jax.set_mesh(mesh): + layer = MergedColumnParallelLinear( + input_size=8, + output_sizes=[4, 4], + mesh=mesh, + ) + self.assertIsNone(layer.bias) + + def test_rejects_non_divisible_component(self): + """Each component size must independently divide TP, so the + per-rank block-concat boundary aligns with the TP cut.""" + mesh = _mesh_1xN(2) + with jax.set_mesh(mesh): + with self.assertRaises(ValueError) as ctx: + MergedColumnParallelLinear( + input_size=16, + output_sizes=[3, 4], + mesh=mesh, # 3 % 2 == 1 + ) + self.assertIn("divisible by TP=2", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sgl_jax/test/layers/test_qwen3_5_gated_delta_net.py b/python/sgl_jax/test/layers/test_qwen3_5_gated_delta_net.py new file mode 100644 index 0000000000..a58289a70c --- /dev/null +++ b/python/sgl_jax/test/layers/test_qwen3_5_gated_delta_net.py @@ -0,0 +1,247 @@ +"""Unit tests for :class:`Qwen3_5GatedDeltaNet`. + +The class wraps :class:`GDNAttnBackend` with the HF Qwen3-5 projection +structure (six independent linears: q/k/v/z/b/a), a gated GemmaRMSNorm, +and an ``out_proj``. The backend + kernels are covered separately; this +file exercises the glue: param shapes, the RMS gate math, and the +end-to-end shape/dtype contract. + +Run with: + JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=8 \\ + python -m pytest test/srt/test_qwen3_5_gated_delta_net.py -v +""" + +from __future__ import annotations + +import os +import unittest +from types import SimpleNamespace + +os.environ.setdefault("JAX_PLATFORMS", "cpu") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8") + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.experimental import mesh_utils +from jax.sharding import AxisType, Mesh, NamedSharding +from jax.sharding import PartitionSpec as P + +from sgl_jax.srt.layers.attention.hybrid_linear_attn_backend import ( + LinearRecurrentAttnBackendMetadata, +) +from sgl_jax.srt.layers.attention.linear.qwen3_5_gated_delta_net import ( + Qwen3_5GatedDeltaNet, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_mesh(): + devices = mesh_utils.create_device_mesh((8,))[:1].reshape((1, 1)) + return Mesh( + devices, + ("data", "tensor"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def _make_config(hidden_size=64, n_kq=1, n_v=2, d_k=4, d_v=8, K=3, eps=1e-6): + return SimpleNamespace( + hidden_size=hidden_size, + linear_num_value_heads=n_v, + linear_num_key_heads=n_kq, + linear_key_head_dim=d_k, + linear_value_head_dim=d_v, + linear_conv_kernel_dim=K, + rms_norm_eps=eps, + ) + + +class _FakeForwardMode: + def __init__(self, decode: bool): + self._decode = decode + + def is_decode(self): + return self._decode + + +class _FakeForwardBatch: + """Minimal forward batch — the layer only reads forward_mode here. + + All ragged-batch info now lives on ``backend.forward_metadata``; the + test sets that directly. + """ + + def __init__(self, is_decode: bool): + self.forward_mode = _FakeForwardMode(is_decode) + + +class _FakePool: + """Stand-in for :class:`RecurrentStatePool` exposing only + ``get_linear_recurrent_layer_cache``.""" + + def __init__(self, recurrent_buffer, conv_buffer): + self._rec = recurrent_buffer + self._conv_list = [conv_buffer] + + def get_linear_recurrent_layer_cache(self, layer_id: int): + return self._rec, self._conv_list + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class Qwen3_5GatedDeltaNetInitTest(unittest.TestCase): + def test_projection_shapes_match_hf_layout(self): + """``in_proj_qkv`` is a merged Q/K/V GEMM; z/b/a stay separate. Param + shapes match the HF Qwen3-5 checkpoint structure.""" + mesh = _make_mesh() + with jax.set_mesh(mesh): + cfg = _make_config(hidden_size=32, n_kq=2, n_v=4, d_k=8, d_v=8, K=3) + layer = Qwen3_5GatedDeltaNet(cfg, layer_id=0, mamba_layer_id=0, mesh=mesh) + key_dim = cfg.linear_num_key_heads * cfg.linear_key_head_dim + value_dim = cfg.linear_num_value_heads * cfg.linear_value_head_dim + qkv_total = 2 * key_dim + value_dim + + # Fused Q/K/V projection: single weight tensor of size 2*key+value. + self.assertEqual(layer.in_proj_qkv.weight.value.shape, (cfg.hidden_size, qkv_total)) + self.assertEqual(layer.in_proj_qkv.output_sizes, [key_dim, key_dim, value_dim]) + # z/b/a stay separate. + self.assertEqual(layer.in_proj_z.weight.value.shape, (cfg.hidden_size, value_dim)) + self.assertEqual( + layer.in_proj_b.weight.value.shape, (cfg.hidden_size, cfg.linear_num_value_heads) + ) + self.assertEqual( + layer.in_proj_a.weight.value.shape, (cfg.hidden_size, cfg.linear_num_value_heads) + ) + self.assertEqual(layer.out_proj.weight.value.shape, (value_dim, cfg.hidden_size)) + self.assertEqual( + layer.attention.conv1d_weight.value.shape, + (layer.conv_dim, cfg.linear_conv_kernel_dim), + ) + self.assertEqual(layer.attention.A_log.value.shape, (cfg.linear_num_value_heads,)) + self.assertEqual(layer.attention.dt_bias.value.shape, (cfg.linear_num_value_heads,)) + self.assertEqual(layer.rms_scale.value.shape, (cfg.linear_value_head_dim,)) + + +class Qwen3_5GatedDeltaNetRmsGateTest(unittest.TestCase): + """``_rms_gate`` = ``RMSNorm(core)·γ * silu(z)``.""" + + def test_matches_explicit_formula(self): + mesh = _make_mesh() + with jax.set_mesh(mesh): + cfg = _make_config(n_v=2, d_v=4, eps=1e-6) + layer = Qwen3_5GatedDeltaNet(cfg, 0, 0, mesh) + gamma = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32) + layer.rms_scale = nnx.Param(gamma) + + T, n_v, d_v = 3, 2, 4 + rng = jax.random.split(jax.random.key(7), 2) + core = jax.random.normal(rng[0], (T, n_v, d_v), dtype=jnp.bfloat16) * 0.5 + z = jax.random.normal(rng[1], (T, n_v, d_v), dtype=jnp.bfloat16) * 0.5 + + got = layer._rms_gate(core, z) + + # Reference RMS over the last axis in fp32. + xf = core.astype(jnp.float32) + rms = jnp.sqrt((xf * xf).mean(axis=-1, keepdims=True) + 1e-6) + ref = (xf / rms) * gamma * jax.nn.silu(z.astype(jnp.float32)) + ref = ref.astype(core.dtype) + np.testing.assert_allclose(got, ref, atol=1e-2, rtol=1e-2) + + +class Qwen3_5GatedDeltaNetEndToEndTest(unittest.TestCase): + """End-to-end shape/dtype/finite contract. + + Numerical correctness lives in ``test_gated_delta.py`` and + ``test_ragged_gated_delta_rule_ref.py`` (which exercise the kernels + directly without a mesh). Here we just confirm the layer wires + projections → conv → recurrence → RMS gate → out_proj without + sharding errors and the output shape matches ``hidden_size``. + """ + + def _run_layer(self, is_decode): + mesh = _make_mesh() + with jax.set_mesh(mesh): + cfg = _make_config(hidden_size=32, n_kq=1, n_v=2, d_k=4, d_v=8, K=3) + layer = Qwen3_5GatedDeltaNet(cfg, 0, 0, mesh) + conv_dim = layer.conv_dim + layer.attention.conv1d_weight = nnx.Param( + jax.random.normal( + jax.random.key(0), (conv_dim, cfg.linear_conv_kernel_dim), dtype=jnp.bfloat16 + ) + * 0.05 + ) + + # 3 reqs of length [3, 2] for extend, B=3 for decode. + if is_decode: + T = 3 + B = 3 + cu_q_lens = None + has_initial_state = None + else: + T = 5 + B = 2 + cu_q_lens = jnp.array([0, 3, 5], dtype=jnp.int32) + has_initial_state = jnp.array([False, False], dtype=jnp.bool_) + + hidden = ( + jax.random.normal(jax.random.key(1), (T, cfg.hidden_size), dtype=jnp.bfloat16) * 0.3 + ) + conv_state = jnp.zeros( + (B + 1, conv_dim, cfg.linear_conv_kernel_dim - 1), + dtype=jnp.bfloat16, + out_sharding=NamedSharding(mesh, P(None, "tensor", None)), + ) + rec_state = jnp.zeros( + ( + B + 1, + cfg.linear_num_value_heads, + cfg.linear_key_head_dim, + cfg.linear_value_head_dim, + ), + dtype=jnp.float32, + out_sharding=NamedSharding(mesh, P(None, "tensor", None, None)), + ) + pool = _FakePool(recurrent_buffer=rec_state, conv_buffer=conv_state) + layer.attention.forward_metadata = LinearRecurrentAttnBackendMetadata( + cu_q_lens=cu_q_lens, + recurrent_indices=jnp.arange(1, B + 1, dtype=jnp.int32), + has_initial_state=has_initial_state, + ) + fb = _FakeForwardBatch(is_decode=is_decode) + return layer(hidden, fb, pool), B, T, cfg, conv_dim + + def test_decode_path(self): + (out, new_conv, new_rec), B, T, cfg, conv_dim = self._run_layer(is_decode=True) + self.assertEqual(out.shape, (T, cfg.hidden_size)) + self.assertEqual(out.dtype, jnp.bfloat16) + # `new_conv` / `new_rec` are full pool tables (kernel scatters + # per-request slots in place). Fake pool was sized `B + 1`. + self.assertEqual(new_conv.shape, (B + 1, conv_dim, cfg.linear_conv_kernel_dim - 1)) + self.assertEqual( + new_rec.shape, + (B + 1, cfg.linear_num_value_heads, cfg.linear_key_head_dim, cfg.linear_value_head_dim), + ) + self.assertTrue(bool(jnp.all(jnp.isfinite(out)))) + + def test_extend_path(self): + (out, new_conv, new_rec), B, T, cfg, conv_dim = self._run_layer(is_decode=False) + self.assertEqual(out.shape, (T, cfg.hidden_size)) + self.assertEqual(out.dtype, jnp.bfloat16) + self.assertEqual(new_conv.shape, (B + 1, conv_dim, cfg.linear_conv_kernel_dim - 1)) + self.assertEqual( + new_rec.shape, + (B + 1, cfg.linear_num_value_heads, cfg.linear_key_head_dim, cfg.linear_value_head_dim), + ) + self.assertTrue(bool(jnp.all(jnp.isfinite(out)))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c522d6f851..92bfc21566 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -473,6 +473,14 @@ def run_one_file( TestFile("python/sgl_jax/test/layers/test_linear_attention.py", 1.5, runner="pytest"), TestFile("test/srt/lora/test_bgmv_backend.py", 4), TestFile("test/srt/lora/test_align_lora_accuracy.py", 3.5), + # GDN (gated DeltaNet) — CPU-only unit tests; each pins + # JAX_PLATFORMS=cpu + 8 fake devices in its header, so they run on + # any TPU runner without consuming TPU chips. + TestFile("python/sgl_jax/test/kernels/gdn/test_gated_delta.py", 1), + TestFile("python/sgl_jax/test/kernels/gdn/test_ragged_gated_delta_rule_ref.py", 1), + TestFile("python/sgl_jax/test/layers/test_gdn_backend.py", 1), + TestFile("python/sgl_jax/test/layers/test_merged_column_parallel_linear.py", 1), + TestFile("python/sgl_jax/test/layers/test_qwen3_5_gated_delta_net.py", 1), ], "unit-test-tpu-v6e-4": [ TestFile("python/sgl_jax/test/test_mesh.py", 1),