Merged

39 commits
3e54c97
refactor(gla): extract LightningAttnBackend and reorganize test suite
cjx0709 May 8, 2026
ddc3f8a
refactor(gla): vendor pallas kernel with internal padding and add Rad…
cjx0709 May 8, 2026
d55d29b
test(gla): fix test suite for refactored backend architecture
cjx0709 May 8, 2026
e65cdfb
feat(gla): add Data Parallelism support and comprehensive test suite
cjx0709 May 8, 2026
2219796
refact dp for lightning_backend
cjx0709 May 9, 2026
ab1e804
fix test case fail with dp refact
cjx0709 May 9, 2026
cda0d80
refact lightning_backend
cjx0709 May 9, 2026
3d41990
suit to rebase code
cjx0709 May 9, 2026
02481bc
adjust the mock states mempool
cjx0709 May 9, 2026
7214199
update
cjx0709 May 9, 2026
80aadf0
algin to sglang GPU for datatype cast and test tol
cjx0709 May 9, 2026
8678d79
Gla ssm state fix (#1045)
JamesBrianD May 9, 2026
ef35d5e
fix(gla): satisfy lint
JamesBrianD May 9, 2026
cc86c08
test(gla): move native reference and drop legacy gla test (#1048)
JamesBrianD May 9, 2026
6e25ba5
test(gla): remove native reference wrapper
JamesBrianD May 9, 2026
b49f3ff
test(gla): move lightning backend test under layers
JamesBrianD May 9, 2026
152d019
Add BailingMoe linear hybrid model
JamesBrianD May 10, 2026
9e8ff94
Fold Bailing hybrid state params into config
JamesBrianD May 10, 2026
1bfa7b5
Fix Bailing group RMSNorm sharding
JamesBrianD May 10, 2026
082ce4f
fix(mla): account for bf16 packing alignment in KV cache cell size es…
cjx0709 May 11, 2026
a1f702a
fix(gla): zero recurrent state for new requests in extend mode
cjx0709 May 11, 2026
774bda4
fix(group_rmsnorm): shard along num_groups instead of group_size
cjx0709 May 11, 2026
b0a586a
fix kernel
cjx0709 May 11, 2026
b1e1728
add native kernel: PYTHONPATH=/sglang-jax/python \
cjx0709 May 12, 2026
a42fa38
fix: native gla (#1067)
JamesBrianD May 12, 2026
5d8745d
fix(moe): hardcode renormalize=True for BailingMoE TopK
JamesBrianD May 12, 2026
8088ce5
debug: log extend shapes before shard_map
JamesBrianD May 12, 2026
2a86c13
fix(precompile): provide recurrent_indices and has_initial_state in d…
JamesBrianD May 12, 2026
a716b6d
cleanup: remove debug shape logging from lightning_backend
JamesBrianD May 12, 2026
7377cf2
add states mask (#1070)
cjx0709 May 13, 2026
61fd82a
fix(precompile): only set recurrent fields for hybrid models (#1071)
JamesBrianD May 13, 2026
942c6f1
fix
cjx0709 May 13, 2026
7c8f4e4
open overlap for linear attention
cjx0709 May 13, 2026
2940eb9
fix: load fused shared experts for bailing linear moe
JamesBrianD May 13, 2026
fd2a60c
fix(gla): mask v_decay_exp to avoid NaN at padded chunk tail (#1073)
Rodrian7 May 14, 2026
93bb96f
test: add aime26 / csimpleqa evals to test/srt/eval
cjx0709 May 14, 2026
0b332f9
Merge origin/main into cjx/gla-kernel
JamesBrianD May 14, 2026
c1d07c4
drop getattr fallback for runner.lightning_config
JamesBrianD May 14, 2026
8053bc4
chore(profiling): add named_scope to GLA/KDA/short-conv prefill+decode
JamesBrianD May 14, 2026
231 changes: 231 additions & 0 deletions python/sgl_jax/srt/configs/bailing_hybrid.py
@@ -0,0 +1,231 @@
from __future__ import annotations

from typing import Any

from transformers import PretrainedConfig


class BailingHybridConfig(PretrainedConfig):
"""Minimal Bailing hybrid config for Ling/Ring 2.5 linear-attention models."""

model_type = "bailing_hybrid"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size: int = 157184,
hidden_size: int = 2048,
intermediate_size: int = 5120,
num_hidden_layers: int = 20,
num_attention_heads: int = 16,
num_key_value_heads: int = 4,
hidden_act: str = "silu",
use_qkv_bias: bool = False,
use_bias: bool = False,
rms_norm_eps: float = 1e-6,
tie_word_embeddings: bool = False,
max_position_embeddings: int = 32768,
rope_theta: float = 600000.0,
rope_scaling: dict[str, Any] | None = None,
pad_token_id: int = 156892,
eos_token_id: int = 156892,
num_experts: int = 256,
num_shared_experts: int = 1,
num_experts_per_tok: int = 8,
n_group: int = 8,
topk_group: int = 4,
moe_intermediate_size: int = 512,
first_k_dense_replace: int = 1,
head_dim: int | None = 128,
use_qk_norm: bool = True,
moe_router_enable_expert_bias: bool = True,
norm_topk_prob: bool = False,
routed_scaling_factor: float = 1.0,
score_function: str = "sigmoid",
router_dtype: str | None = None,
layer_group_size: int = 1,
layers_block_type: list[str] | None = None,
group_norm_size: int = 1,
linear_silu: bool = False,
use_linear_silu: bool | None = None,
linear_rope: bool = True,
full_attention_type: str = "mla",
kv_lora_rank: int = 512,
q_lora_rank: int | None = None,
qk_rope_head_dim: int = 64,
qk_nope_head_dim: int = 128,
v_head_dim: int = 128,
rope_interleave: bool = True,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.use_qkv_bias = use_qkv_bias
self.use_bias = use_bias
self.rms_norm_eps = rms_norm_eps
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.head_dim = head_dim or hidden_size // num_attention_heads
self.use_qk_norm = use_qk_norm

self.num_experts = num_experts
self.num_shared_experts = num_shared_experts
self.num_experts_per_tok = num_experts_per_tok
self.n_group = n_group
self.topk_group = topk_group
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
self.norm_topk_prob = norm_topk_prob
self.routed_scaling_factor = routed_scaling_factor
self.score_function = score_function
self.router_dtype = router_dtype

self.layer_group_size = layer_group_size
self._layers_block_type = list(layers_block_type) if layers_block_type is not None else None
self.group_norm_size = group_norm_size
self.linear_silu = linear_silu if use_linear_silu is None else use_linear_silu
self.use_linear_silu = self.linear_silu
self.linear_rope = linear_rope
self.num_linear_key_value_heads = num_attention_heads
self.full_attention_type = full_attention_type

self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.rope_interleave = rope_interleave

super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

@property
def layers_block_type(self) -> list[str]:
if self._layers_block_type is not None:
return self._layers_block_type
if self.layer_group_size <= 0:
raise ValueError(f"layer_group_size must be positive, got {self.layer_group_size}")
return [
"attention" if (layer_id + 1) % self.layer_group_size == 0 else "linear_attention"
for layer_id in range(self.num_hidden_layers)
]

@layers_block_type.setter
def layers_block_type(self, value: list[str] | None) -> None:
self._layers_block_type = list(value) if value is not None else None

@property
def linear_layer_ids(self) -> list[int]:
return [
i
for i, block_type in enumerate(self.layers_block_type)
if str(block_type).lower() == "linear_attention"
]

@property
def full_attention_layer_ids(self) -> list[int]:
return [
i
for i, block_type in enumerate(self.layers_block_type)
if str(block_type).lower() in {"attention", "full_attention"}
]

@property
def linear_state_params(self):
from sgl_jax.srt.mem_cache.recurrent_state_pool import (
LinearRecurrentStateParams,
recurrent_state_dtype,
)

return LinearRecurrentStateParams(
layers=self.linear_layer_ids,
num_heads=self.num_linear_key_value_heads,
head_dim=self.head_dim,
conv_kernel_size=1,
dtype=recurrent_state_dtype(),
)

@property
def linear_attn_config(self) -> dict[str, Any]:
return {
"kda_layers": self.linear_layer_ids,
"num_heads": self.num_attention_heads,
"head_dim": self.head_dim,
"short_conv_kernel_size": 1,
}


def get_bailing_hybrid_config(hf_config: Any) -> BailingHybridConfig | None:
if not _is_bailing_hybrid_config(hf_config):
return None

num_hidden_layers = int(hf_config.num_hidden_layers)
linear_layer_ids, _ = _get_layer_ids(hf_config, num_hidden_layers)
if not linear_layer_ids:
return None

if isinstance(hf_config, BailingHybridConfig):
return hf_config

config_kwargs = hf_config.to_dict() if hasattr(hf_config, "to_dict") else dict(vars(hf_config))
config_kwargs["layers_block_type"] = list(hf_config.layers_block_type)
return BailingHybridConfig(**config_kwargs)


def _is_bailing_hybrid_config(hf_config: Any) -> bool:
if getattr(hf_config, "model_type", None) == "bailing_hybrid":
return True
architectures = getattr(hf_config, "architectures", None) or []
return any(str(arch) == "BailingMoeV2_5ForCausalLM" for arch in architectures)


def _get_layer_ids(hf_config: Any, num_hidden_layers: int) -> tuple[list[int], list[int]]:
layers_block_type = getattr(hf_config, "layers_block_type", None)
if layers_block_type is not None:
if len(layers_block_type) != num_hidden_layers:
raise ValueError(
f"layers_block_type length ({len(layers_block_type)}) must match "
f"num_hidden_layers ({num_hidden_layers})"
)
linear_layer_ids = []
full_attention_layer_ids = []
for layer_id, block_type in enumerate(layers_block_type):
normalized = str(block_type).lower()
if normalized == "linear_attention":
linear_layer_ids.append(layer_id)
elif normalized in {"attention", "full_attention"}:
full_attention_layer_ids.append(layer_id)
else:
raise ValueError(f"Unsupported Bailing hybrid layer block type: {block_type}")
return linear_layer_ids, full_attention_layer_ids

layer_group_size = int(hf_config.layer_group_size)
if layer_group_size <= 0:
raise ValueError(f"layer_group_size must be positive, got {layer_group_size}")

linear_layer_ids = []
full_attention_layer_ids = []
for layer_id in range(num_hidden_layers):
if (layer_id + 1) % layer_group_size == 0:
full_attention_layer_ids.append(layer_id)
else:
linear_layer_ids.append(layer_id)
return linear_layer_ids, full_attention_layer_ids


__all__ = [
"BailingHybridConfig",
"get_bailing_hybrid_config",
]
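
Note (not part of the diff): when `layers_block_type` is omitted, the config derives it from `layer_group_size` by marking every `layer_group_size`-th layer as full attention and the rest as linear attention. A minimal sketch that replays the same modulo rule, assuming 8 layers and a group size of 4 purely for illustration:

# Illustrative only: mirrors the layers_block_type default in BailingHybridConfig.
num_hidden_layers = 8
layer_group_size = 4
layers_block_type = [
    "attention" if (layer_id + 1) % layer_group_size == 0 else "linear_attention"
    for layer_id in range(num_hidden_layers)
]
linear_layer_ids = [i for i, t in enumerate(layers_block_type) if t == "linear_attention"]
full_attention_layer_ids = [i for i, t in enumerate(layers_block_type) if t != "linear_attention"]
print(full_attention_layer_ids)  # [3, 7]
print(linear_layer_ids)          # [0, 1, 2, 4, 5, 6]

With the default `layer_group_size` of 1 every layer resolves to "attention", so a hybrid checkpoint is expected to provide either an explicit `layers_block_type` or a group size greater than 1; otherwise `get_bailing_hybrid_config` finds no linear layers and returns None.
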
2 changes: 2 additions & 0 deletions python/sgl_jax/srt/hf_transformers_utils.py
@@ -18,13 +18,15 @@
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from sgl_jax.srt.configs.bailing_hybrid import BailingHybridConfig
from sgl_jax.srt.configs.kimi_linear import KimiLinearConfig
from sgl_jax.srt.managers.tiktoken_tokenizer import TiktokenTokenizer
from sgl_jax.srt.utils.common_utils import is_remote_url, lru_cache_frozenset

_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
cls.model_type: cls
for cls in [
BailingHybridConfig,
KimiLinearConfig,
]
}
125 changes: 125 additions & 0 deletions python/sgl_jax/srt/kernels/simple_gla/native.py
@@ -0,0 +1,125 @@
"""Native JAX reference implementation of GLA (Gated Linear Attention).

This module provides pure JAX implementations of GLA decode and prefill
operations. These implementations use jnp.einsum and jax.lax.scan without
Pallas kernels, matching the same dtype as the kernel under test.
"""

import jax
import jax.numpy as jnp


def naive_gla_decode(
q: jax.Array,
k: jax.Array,
v: jax.Array,
g_gamma: jax.Array,
h0: jax.Array,
scale: float | None = None,
) -> tuple[jax.Array, jax.Array]:
"""Naive GLA decode using jnp.einsum.

Args:
q: Query tensor [B, 1, H, K]
k: Key tensor [B, 1, H, K]
v: Value tensor [B, 1, H, K]
g_gamma: Gate decay per head [H], negative values (e.g., ALiBi slopes)
h0: Initial state [B, H, K, K]
scale: Optional output scaling factor

Returns:
output: [B, 1, H, K]
h1: Updated state [B, H, K, K]
"""
B, T, H, K = q.shape
assert T == 1, f"Decode expects T=1, got {T}"

q_t = q[:, 0].astype(jnp.float32)
k_t = k[:, 0].astype(jnp.float32)
v_t = v[:, 0].astype(jnp.float32)
g_gamma = g_gamma.astype(jnp.float32)
h0 = h0.astype(jnp.float32)

if scale is None:
scale = K**-0.5

decay = jnp.exp(g_gamma)[None, :, None, None]
kv = jnp.einsum("bhk,bhv->bhkv", k_t, v_t)
h1 = decay * h0 + kv
o = jnp.einsum("bhk,bhkv->bhv", q_t, h1)
o = o * scale

output = o[:, None, :, :]

return output, h1


def naive_gla_prefill(
q: jax.Array,
k: jax.Array,
v: jax.Array,
g_gamma: jax.Array,
h0: jax.Array,
cu_seqlens: jax.Array,
scale: float | None = None,
) -> tuple[jax.Array, jax.Array]:
"""Naive GLA prefill using per-request scan + jnp.einsum.

Args:
q: Query tensor [1, T_total, H, K] (varlen packed)
k: Key tensor [1, T_total, H, K] (varlen packed)
v: Value tensor [1, T_total, H, K] (varlen packed)
g_gamma: Gate decay per head [H], negative values
h0: Initial state per request [B, H, K, K]
cu_seqlens: Cumulative sequence lengths [B+1], e.g., [0, 128, 384] for 2 requests
scale: Optional output scaling factor

Returns:
output: [1, T_total, H, K]
h_final: Final state per request [B, H, K, K]
"""
assert q.shape[0] == 1, f"Prefill expects batch=1 (varlen), got {q.shape[0]}"

q = q[0].astype(jnp.float32)
k = k[0].astype(jnp.float32)
v = v[0].astype(jnp.float32)
g_gamma = g_gamma.astype(jnp.float32)
h0 = h0.astype(jnp.float32)

T = q.shape[0]
_, K = q.shape[1], q.shape[2]

if scale is None:
scale = K**-0.5

token_idx = jnp.arange(T, dtype=cu_seqlens.dtype)
seq_ids = jnp.searchsorted(cu_seqlens[1:], token_idx, side="right")
reset_mask = token_idx == cu_seqlens[:-1][seq_ids]
decay = jnp.exp(g_gamma)

def scan_fn(carry, inputs):
h_prev, final_states = carry
seq_id, do_reset, q_t, k_t, v_t = inputs
h = jnp.where(do_reset, h0[seq_id], h_prev)
kv = jnp.einsum("hk,hv->hkv", k_t, v_t)
h = decay[:, None, None] * h + kv
o_t = jnp.einsum("hk,hkv->hv", q_t, h)
final_states = final_states.at[seq_id].set(h)
return (h, final_states), o_t

init_carry = (
jnp.zeros_like(h0[0]),
h0,
)
(_, h_final), output = jax.lax.scan(
scan_fn,
init_carry,
(seq_ids, reset_mask, q, k, v),
)

output = output * scale

return output[None, :, :, :], h_final


__all__ = ["naive_gla_decode", "naive_gla_prefill"]
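
A minimal usage sketch for the decode reference (not part of the PR; it assumes the sgl_jax package from this branch is importable, e.g. with python/ on PYTHONPATH, and picks B=2, H=4, K=64 arbitrarily):

# Smoke test of the native GLA decode reference; shapes follow the docstrings above.
import jax
import jax.numpy as jnp
from sgl_jax.srt.kernels.simple_gla.native import naive_gla_decode

B, H, K = 2, 4, 64
rq, rk, rv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rq, (B, 1, H, K), dtype=jnp.bfloat16)
k = jax.random.normal(rk, (B, 1, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(rv, (B, 1, H, K), dtype=jnp.bfloat16)
g_gamma = -jnp.arange(1, H + 1, dtype=jnp.float32) / H  # negative per-head decays
h0 = jnp.zeros((B, H, K, K), dtype=jnp.float32)          # fresh recurrent state

out, h1 = naive_gla_decode(q, k, v, g_gamma, h0)
print(out.shape, h1.shape)  # (2, 1, 4, 64) (2, 4, 64, 64)

With a zero initial state the first decode step reduces to o = scale * q @ (k^T v), which makes the reference easy to check by hand before comparing it against the Pallas kernel.
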