263 changes: 263 additions & 0 deletions python/sgl_jax/srt/mem_cache/recurrent_state_pool.py
@@ -0,0 +1,263 @@
"""RecurrentStatePool -- buffer pool for linear recurrent layers (KDA/Mamba/GDN)."""

from __future__ import annotations

import os

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.tree_util import register_pytree_node_class

_DTYPE_MAP = {
"float32": jnp.float32,
"bfloat16": jnp.bfloat16,
"float16": jnp.float16,
}


def _resolve_dtype(env_var: str, default):
    """Return the dtype named by the environment variable (a key of _DTYPE_MAP), or default if unset."""
name = os.environ.get(env_var)
return _DTYPE_MAP[name] if name else default


def _ceil_to(value: int, divisor: int) -> int:
    """Round value up to the nearest multiple of divisor, e.g. _ceil_to(11, 4) == 12."""
return (value + divisor - 1) // divisor * divisor


@register_pytree_node_class
class RecurrentStatePool:
    """Buffer pool holding, per linear recurrent layer, the recurrent (temporal) state
    and the short-convolution state. Slot 0 is reserved as a dummy slot, so the pool
    allocates size + 1 slots, rounded up to a multiple of dp_size.
    """

def __init__(
self,
linear_recurrent_layer_ids: list[int],
size: int,
num_heads: int,
head_dim: int,
conv_kernel_size: int,
mesh: Mesh,
dp_size: int = 1,
recurrent_partition_axis: str = "tensor",
conv_partition_axis: str = "tensor",
data_partition_axis: str = "data",
temporal_dtype=None,
conv_dtype=None,
num_k_heads: int | None = None,
head_k_dim: int | None = None,
):
if temporal_dtype is None:
temporal_dtype = _resolve_dtype("SGLANG_JAX_RECURRENT_STATE_DTYPE", jnp.float32)
if conv_dtype is None:
conv_dtype = _resolve_dtype("SGLANG_JAX_CONV_STATE_DTYPE", jnp.bfloat16)
self.temporal_dtype = temporal_dtype
self.conv_dtype = conv_dtype

if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim

assert len(set(linear_recurrent_layer_ids)) == len(linear_recurrent_layer_ids), (
f"linear_recurrent_layer_ids must not contain duplicates, "
f"got {linear_recurrent_layer_ids}"
)
self.linear_recurrent_layer_ids: list[int] = list(linear_recurrent_layer_ids)
self.layers_mapping: dict[int, int] = {
layer_id: idx for idx, layer_id in enumerate(self.linear_recurrent_layer_ids)
}
self.num_linear_recurrent_layers: int = len(self.linear_recurrent_layer_ids)

self.size = size
self.dp_size = dp_size
self.num_heads = num_heads
self.head_dim = head_dim
self.num_k_heads = num_k_heads
self.head_k_dim = head_k_dim
self.conv_kernel_size = conv_kernel_size

proj_v = num_heads * head_dim
proj_k = num_k_heads * head_k_dim
self.proj_size = proj_v + 2 * proj_k

        # total_slots: size + 1 (slot 0 is a dummy slot), rounded up to a multiple of dp_size
self.total_slots = _ceil_to(size + 1, dp_size)
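        # e.g. size=16, dp_size=4 -> total_slots = _ceil_to(17, 4) = 20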

self.mesh = mesh
self.recurrent_partition_axis = recurrent_partition_axis
self.conv_partition_axis = conv_partition_axis
self.data_partition_axis = data_partition_axis

recurrent_axis_size = mesh.shape[recurrent_partition_axis]
conv_axis_size = mesh.shape[conv_partition_axis]
assert num_heads % recurrent_axis_size == 0, (
f"num_heads {num_heads} must be divisible by "
f"'{recurrent_partition_axis}' size {recurrent_axis_size}"
)
assert num_k_heads % recurrent_axis_size == 0, (
f"num_k_heads {num_k_heads} must be divisible by "
f"'{recurrent_partition_axis}' size {recurrent_axis_size}"
)
assert self.proj_size % conv_axis_size == 0, (
f"proj_size {self.proj_size} must be divisible by "
f"'{conv_partition_axis}' size {conv_axis_size}"
)

self.recurrent_sharding = NamedSharding(
mesh, P(data_partition_axis, recurrent_partition_axis, None, None)
)
self.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None))
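        # Resulting layouts (see _create_buffers): recurrent state is
        # (total_slots, num_heads, head_dim, head_dim) with slots sharded on the data
        # axis and heads on recurrent_partition_axis; conv state is
        # (total_slots, proj_size, conv_kernel_size - 1) with proj_size sharded on
        # conv_partition_axis.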

self.recurrent_buffers, self.conv_buffers = self._create_buffers()

    def _create_buffers(self) -> tuple[list, list]:
        """Allocate zero-initialized, sharded recurrent and conv buffers for every layer."""
recurrent_shape = (self.total_slots, self.num_heads, self.head_dim, self.head_dim)
conv_shape = (self.total_slots, self.proj_size, self.conv_kernel_size - 1)
temporal_dtype = self.temporal_dtype
conv_dtype = self.conv_dtype

with self.mesh:
recurrent_buffers = []
for _ in range(self.num_linear_recurrent_layers):
buf = jax.jit(
lambda: jnp.zeros(shape=recurrent_shape, dtype=temporal_dtype),
out_shardings=self.recurrent_sharding,
)()
recurrent_buffers.append(buf)

            conv_buffers = []
            for _ in range(self.num_linear_recurrent_layers):
                # each layer's conv state is stored as a list of buffers (a single buffer here)
                inner = []
buf = jax.jit(
lambda: jnp.zeros(shape=conv_shape, dtype=conv_dtype),
out_shardings=self.conv_sharding,
)()
inner.append(buf)
conv_buffers.append(inner)

return recurrent_buffers, conv_buffers

    def clear_slot(self, idx_or_indices) -> None:
        """Zero the recurrent and conv state of the given slot index (or indices) in every layer."""
indices = [idx_or_indices] if isinstance(idx_or_indices, int) else list(idx_or_indices)
if not indices:
return

idx_arr = jnp.asarray(indices, dtype=jnp.int32)
with jax.set_mesh(self.mesh):
for layer in range(self.num_linear_recurrent_layers):
self.recurrent_buffers[layer] = self.recurrent_buffers[layer].at[idx_arr].set(0)
for inner in range(len(self.conv_buffers[layer])):
self.conv_buffers[layer][inner] = (
self.conv_buffers[layer][inner].at[idx_arr].set(0)
)

    def get_linear_recurrent_layer_cache(self, layer_id: int):
        """Return (recurrent_buffer, conv_buffers) for the given global layer_id."""
if layer_id not in self.layers_mapping:
raise ValueError(
f"layer_id={layer_id} is not a registered linear recurrent layer. "
f"Registered: {self.linear_recurrent_layer_ids}"
)
idx = self.layers_mapping[layer_id]
return self.recurrent_buffers[idx], self.conv_buffers[idx]

    def replace_buffer(self, buffers) -> None:
        """Install externally produced (recurrent_buffers, conv_buffers) into the pool."""
new_recurrent, new_conv = buffers

assert len(new_recurrent) == self.num_linear_recurrent_layers
assert len(new_conv) == self.num_linear_recurrent_layers

# tp_size==1 sharding fix: see MHATokenToKVPool.replace_buffer
tp_degenerate = self.mesh.shape.get("tensor", 1) == 1
for layer in range(self.num_linear_recurrent_layers):
buf = new_recurrent[layer]
if tp_degenerate:
buf = jax.device_put(buf, self.recurrent_sharding)
self.recurrent_buffers[layer] = buf

assert len(new_conv[layer]) == len(self.conv_buffers[layer])
for i in range(len(new_conv[layer])):
cbuf = new_conv[layer][i]
if tp_degenerate:
cbuf = jax.device_put(cbuf, self.conv_sharding)
self.conv_buffers[layer][i] = cbuf

    def clear(self) -> None:
        """Zero all recurrent and conv buffers for every layer."""
for layer in range(self.num_linear_recurrent_layers):
self.recurrent_buffers[layer] = jnp.zeros_like(self.recurrent_buffers[layer])
for inner in range(len(self.conv_buffers[layer])):
self.conv_buffers[layer][inner] = jnp.zeros_like(self.conv_buffers[layer][inner])

# --- pytree ---
def tree_flatten(self):
children = (self.recurrent_buffers, self.conv_buffers)
aux = (
tuple(self.linear_recurrent_layer_ids),
self.size,
self.dp_size,
self.total_slots,
self.num_heads,
self.head_dim,
self.num_k_heads,
self.head_k_dim,
self.conv_kernel_size,
self.temporal_dtype,
self.conv_dtype,
self.mesh,
self.recurrent_partition_axis,
self.conv_partition_axis,
self.data_partition_axis,
self.recurrent_sharding,
self.conv_sharding,
)
return children, aux

@classmethod
def tree_unflatten(cls, aux_data, children):
(
linear_recurrent_layer_ids_tup,
size,
dp_size,
total_slots,
num_heads,
head_dim,
num_k_heads,
head_k_dim,
conv_kernel_size,
temporal_dtype,
conv_dtype,
mesh,
recurrent_partition_axis,
conv_partition_axis,
data_partition_axis,
recurrent_sharding,
conv_sharding,
) = aux_data
obj = cls.__new__(cls)
obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup)
obj.layers_mapping = {
layer_id: idx for idx, layer_id in enumerate(obj.linear_recurrent_layer_ids)
}
obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids)
obj.size = size
obj.dp_size = dp_size
obj.total_slots = total_slots
obj.num_heads = num_heads
obj.head_dim = head_dim
obj.num_k_heads = num_k_heads
obj.head_k_dim = head_k_dim
obj.conv_kernel_size = conv_kernel_size
obj.temporal_dtype = temporal_dtype
obj.conv_dtype = conv_dtype
proj_v = num_heads * head_dim
proj_k = num_k_heads * head_k_dim
obj.proj_size = proj_v + 2 * proj_k
obj.mesh = mesh
obj.recurrent_partition_axis = recurrent_partition_axis
obj.conv_partition_axis = conv_partition_axis
obj.data_partition_axis = data_partition_axis
obj.recurrent_sharding = recurrent_sharding
obj.conv_sharding = conv_sharding
new_recurrent, new_conv = children
obj.recurrent_buffers = list(new_recurrent)
obj.conv_buffers = [list(inner) for inner in new_conv]
return obj
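

if __name__ == "__main__":
    # Illustrative smoke test, not part of the pool implementation. The mesh axis
    # names ("data", "tensor") and the sizes below are assumptions chosen so the
    # sketch runs on a single device.
    import numpy as np

    devices = np.asarray(jax.devices()[:1]).reshape(1, 1)
    mesh = Mesh(devices, axis_names=("data", "tensor"))
    pool = RecurrentStatePool(
        linear_recurrent_layer_ids=[3, 7],
        size=16,
        num_heads=8,
        head_dim=64,
        conv_kernel_size=4,
        mesh=mesh,
    )
    recurrent_buf, conv_bufs = pool.get_linear_recurrent_layer_cache(3)
    print(recurrent_buf.shape)  # (17, 8, 64, 64): 16 slots + 1 dummy, rounded to dp_size=1
    print(conv_bufs[0].shape)   # (17, 1536, 3): proj_size = 8*64 + 2*8*64, conv_kernel_size - 1
    pool.clear_slot([1, 2])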