Commit e9043d3

feat(mem_cache): add RecurrentStatePool with DP sharding
Migrated from epic/support_kimi_linear with DP support added. Pure buffer pool
for linear recurrent layers (KDA/Mamba/GDN). Key changes vs epic:

- max_num_reqs → size (align with upstream sglang MambaPool)
- dp_size param with slot dim sharded on P("data", ...)
- total_slots = ceil_to(size+1, dp_size) for DP divisibility
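For intuition, a minimal sketch of the slot-count and sharding arithmetic described above. The sizes are made up for illustration; only ceil_to mirrors the _ceil_to helper added in this file.

# Hypothetical sizes, for illustration only -- not values taken from this commit.
def ceil_to(value: int, divisor: int) -> int:
    return (value + divisor - 1) // divisor * divisor

size, dp_size = 256, 8                     # 256 usable request slots, 8-way data parallelism
total_slots = ceil_to(size + 1, dp_size)   # 257 rounds up to 264: +1 dummy slot, padded to a multiple of 8
# Each recurrent buffer is [total_slots, H, D, D] with the slot dim sharded on the "data" axis,
# i.e. PartitionSpec("data", "tensor", None, None), so every DP shard owns 264 // 8 = 33 slots.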
1 parent 36e6347 commit e9043d3

1 file changed

Lines changed: 297 additions & 0 deletions

@@ -0,0 +1,297 @@
"""RecurrentStatePool -- recurrent + conv state buffer pool for linear recurrent layers.

Dual list containers:
    recurrent_buffers: list[jax.Array], length L, each [total_slots, H, D, D]
    conv_buffers: list[list[jax.Array]], outer L, inner 1, each [total_slots, proj_size, K-1]

Slot 0 reserved as dummy; valid slots start from 1 (aligned with sglang MambaPool).
total_slots = _ceil_to(size + 1, dp_size) so the slot dim is evenly divisible
by dp_size for P("data", ...) sharding.

Pure buffer pool: slot allocator state lives in HybridReqToTokenPool.
Does NOT inherit from KVCache.
"""

from __future__ import annotations

import os

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.tree_util import register_pytree_node_class

_DTYPE_MAP = {
    "float32": jnp.float32,
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
}


def _resolve_dtype(env_var: str, default):
    name = os.environ.get(env_var)
    return _DTYPE_MAP[name] if name else default


def _ceil_to(value: int, divisor: int) -> int:
    return (value + divisor - 1) // divisor * divisor


@register_pytree_node_class
class RecurrentStatePool:
    """Recurrent + conv state buffer pool (per-slot indexing, no slot allocator)."""

    def __init__(
        self,
        linear_recurrent_layer_ids: list[int],
        size: int,
        num_heads: int,
        head_dim: int,
        conv_kernel_size: int,
        mesh: Mesh,
        dp_size: int = 1,
        recurrent_partition_axis: str = "tensor",
        conv_partition_axis: str = "tensor",
        data_partition_axis: str = "data",
        temporal_dtype=None,
        conv_dtype=None,
        num_k_heads: int | None = None,
        head_k_dim: int | None = None,
    ):
        if temporal_dtype is None:
            temporal_dtype = _resolve_dtype("SGLANG_JAX_RECURRENT_STATE_DTYPE", jnp.float32)
        if conv_dtype is None:
            conv_dtype = _resolve_dtype("SGLANG_JAX_CONV_STATE_DTYPE", jnp.bfloat16)
        self.temporal_dtype = temporal_dtype
        self.conv_dtype = conv_dtype

        if num_k_heads is None:
            num_k_heads = num_heads
        if head_k_dim is None:
            head_k_dim = head_dim

        assert len(set(linear_recurrent_layer_ids)) == len(linear_recurrent_layer_ids), (
            f"linear_recurrent_layer_ids must not contain duplicates, "
            f"got {linear_recurrent_layer_ids}"
        )
        self.linear_recurrent_layer_ids: list[int] = list(linear_recurrent_layer_ids)
        self.layers_mapping: dict[int, int] = {
            layer_id: idx for idx, layer_id in enumerate(self.linear_recurrent_layer_ids)
        }
        self.num_linear_recurrent_layers: int = len(self.linear_recurrent_layer_ids)

        self.size = size
        self.dp_size = dp_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_k_heads = num_k_heads
        self.head_k_dim = head_k_dim
        self.conv_kernel_size = conv_kernel_size

        proj_v = num_heads * head_dim
        proj_k = num_k_heads * head_k_dim
        self.proj_size = proj_v + 2 * proj_k

        # total_slots: size+1 (for dummy slot 0), ceil to dp_size
        self.total_slots = _ceil_to(size + 1, dp_size)

        assert size > 0, f"size must be > 0, got {size}"
        assert num_heads > 0
        assert head_dim > 0
        assert num_k_heads > 0
        assert head_k_dim > 0
        assert (
            conv_kernel_size >= 2
        ), f"conv_kernel_size must be >= 2 (got {conv_kernel_size}); K=1 produces empty conv buffers."
        assert self.proj_size > 0

        self.mesh = mesh
        self.recurrent_partition_axis = recurrent_partition_axis
        self.conv_partition_axis = conv_partition_axis
        self.data_partition_axis = data_partition_axis

        recurrent_axis_size = mesh.shape[recurrent_partition_axis]
        conv_axis_size = mesh.shape[conv_partition_axis]
        assert num_heads % recurrent_axis_size == 0, (
            f"num_heads {num_heads} must be divisible by "
            f"'{recurrent_partition_axis}' size {recurrent_axis_size}"
        )
        assert num_k_heads % recurrent_axis_size == 0, (
            f"num_k_heads {num_k_heads} must be divisible by "
            f"'{recurrent_partition_axis}' size {recurrent_axis_size}"
        )
        assert self.proj_size % conv_axis_size == 0, (
            f"proj_size {self.proj_size} must be divisible by "
            f"'{conv_partition_axis}' size {conv_axis_size}"
        )

        self.recurrent_sharding = NamedSharding(
            mesh, P(data_partition_axis, recurrent_partition_axis, None, None)
        )
        self.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None))

        self.recurrent_buffers, self.conv_buffers = self._create_buffers()

    def _create_buffers(self) -> tuple[list, list]:
        recurrent_shape = (self.total_slots, self.num_heads, self.head_dim, self.head_dim)
        conv_shape = (self.total_slots, self.proj_size, self.conv_kernel_size - 1)
        temporal_dtype = self.temporal_dtype
        conv_dtype = self.conv_dtype

        with self.mesh:
            recurrent_buffers = []
            for _ in range(self.num_linear_recurrent_layers):
                buf = jax.jit(
                    lambda: jnp.zeros(shape=recurrent_shape, dtype=temporal_dtype),
                    out_shardings=self.recurrent_sharding,
                )()
                recurrent_buffers.append(buf)

            conv_buffers = []
            for _ in range(self.num_linear_recurrent_layers):
                inner = []
                buf = jax.jit(
                    lambda: jnp.zeros(shape=conv_shape, dtype=conv_dtype),
                    out_shardings=self.conv_sharding,
                )()
                inner.append(buf)
                conv_buffers.append(inner)

        return recurrent_buffers, conv_buffers

    def clear_slot(self, idx_or_indices) -> None:
        """Zero the per-slot buffers for the given slot(s). Used for clear-on-alloc."""
        indices = [idx_or_indices] if isinstance(idx_or_indices, int) else list(idx_or_indices)
        if not indices:
            return

        idx_arr = jnp.asarray(indices, dtype=jnp.int32)
        with jax.set_mesh(self.mesh):
            for layer in range(self.num_linear_recurrent_layers):
                self.recurrent_buffers[layer] = self.recurrent_buffers[layer].at[idx_arr].set(0)
                for inner in range(len(self.conv_buffers[layer])):
                    self.conv_buffers[layer][inner] = (
                        self.conv_buffers[layer][inner].at[idx_arr].set(0)
                    )

    def get_linear_recurrent_layer_cache(self, layer_id: int):
        """Read the per-layer view, keyed by model-global layer_id.

        Returns (recurrent_per_layer, conv_per_layer).
        """
        if layer_id not in self.layers_mapping:
            raise ValueError(
                f"layer_id={layer_id} is not a registered linear recurrent layer. "
                f"Registered: {self.linear_recurrent_layer_ids}"
            )
        idx = self.layers_mapping[layer_id]
        return self.recurrent_buffers[idx], self.conv_buffers[idx]

    def replace_buffer(self, buffers) -> None:
        """Update both buffer-list references after a JIT donate.

        buffers: tuple[list[jax.Array], list[list[jax.Array]]]
        """
        new_recurrent, new_conv = buffers

        assert len(new_recurrent) == self.num_linear_recurrent_layers
        assert len(new_conv) == self.num_linear_recurrent_layers
        for layer in range(self.num_linear_recurrent_layers):
            assert len(new_conv[layer]) == len(self.conv_buffers[layer])

        tp_degenerate = self.mesh.shape.get("tensor", 1) == 1
        for layer in range(self.num_linear_recurrent_layers):
            buf = new_recurrent[layer]
            if tp_degenerate and hasattr(self, "recurrent_sharding"):
                buf = jax.device_put(buf, self.recurrent_sharding)
            self.recurrent_buffers[layer] = buf

        for layer in range(self.num_linear_recurrent_layers):
            for i in range(len(new_conv[layer])):
                buf = new_conv[layer][i]
                if tp_degenerate and hasattr(self, "conv_sharding"):
                    buf = jax.device_put(buf, self.conv_sharding)
                self.conv_buffers[layer][i] = buf

    def clear(self) -> None:
        """Full reset: zero out every layer's recurrent + conv buffer."""
        for layer in range(self.num_linear_recurrent_layers):
            self.recurrent_buffers[layer] = jnp.zeros_like(self.recurrent_buffers[layer])
            for inner in range(len(self.conv_buffers[layer])):
                self.conv_buffers[layer][inner] = jnp.zeros_like(self.conv_buffers[layer][inner])

    # --- pytree ---
    def tree_flatten(self):
        children = (self.recurrent_buffers, self.conv_buffers)
        aux = (
            tuple(self.linear_recurrent_layer_ids),
            self.size,
            self.dp_size,
            self.total_slots,
            self.num_heads,
            self.head_dim,
            self.num_k_heads,
            self.head_k_dim,
            self.conv_kernel_size,
            self.temporal_dtype,
            self.conv_dtype,
            self.mesh,
            self.recurrent_partition_axis,
            self.conv_partition_axis,
            self.data_partition_axis,
            self.recurrent_sharding,
            self.conv_sharding,
        )
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (
            linear_recurrent_layer_ids_tup,
            size,
            dp_size,
            total_slots,
            num_heads,
            head_dim,
            num_k_heads,
            head_k_dim,
            conv_kernel_size,
            temporal_dtype,
            conv_dtype,
            mesh,
            recurrent_partition_axis,
            conv_partition_axis,
            data_partition_axis,
            recurrent_sharding,
            conv_sharding,
        ) = aux_data
        obj = cls.__new__(cls)
        obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup)
        obj.layers_mapping = {
            layer_id: idx for idx, layer_id in enumerate(obj.linear_recurrent_layer_ids)
        }
        obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids)
        obj.size = size
        obj.dp_size = dp_size
        obj.total_slots = total_slots
        obj.num_heads = num_heads
        obj.head_dim = head_dim
        obj.num_k_heads = num_k_heads
        obj.head_k_dim = head_k_dim
        obj.conv_kernel_size = conv_kernel_size
        obj.temporal_dtype = temporal_dtype
        obj.conv_dtype = conv_dtype
        proj_v = num_heads * head_dim
        proj_k = num_k_heads * head_k_dim
        obj.proj_size = proj_v + 2 * proj_k
        obj.mesh = mesh
        obj.recurrent_partition_axis = recurrent_partition_axis
        obj.conv_partition_axis = conv_partition_axis
        obj.data_partition_axis = data_partition_axis
        obj.recurrent_sharding = recurrent_sharding
        obj.conv_sharding = conv_sharding
        new_recurrent, new_conv = children
        obj.recurrent_buffers = list(new_recurrent)
        obj.conv_buffers = [list(inner) for inner in new_conv]
        return obj
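
A rough usage sketch, not part of the commit: the layer ids, head counts, and single-device 1x1 mesh below are invented for illustration, and RecurrentStatePool is assumed to be imported from the module added above.

import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical 1x1 mesh just to exercise the API; a real deployment would use a
# multi-device mesh whose "data" / "tensor" axes match the DP size / TP degree.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "tensor"))

pool = RecurrentStatePool(
    linear_recurrent_layer_ids=[0, 3, 6],   # invented model-global layer ids
    size=16,                                # usable slots; slot 0 stays reserved as dummy
    num_heads=4,
    head_dim=8,
    conv_kernel_size=4,
    mesh=mesh,
    dp_size=1,
)

# Per-layer views are keyed by model-global layer_id, not by list position.
recurrent, conv = pool.get_linear_recurrent_layer_cache(3)
assert recurrent.shape == (pool.total_slots, 4, 8, 8)
assert conv[0].shape == (pool.total_slots, pool.proj_size, 3)

# Clear-on-alloc for freshly assigned slots (valid slots start at 1).
pool.clear_slot([1, 2])

# Registered as a pytree, so the buffers can flow through jax transformations.
leaves, treedef = jax.tree_util.tree_flatten(pool)

Slot allocation itself is not shown here; per the module docstring it lives in HybridReqToTokenPool, and this class only owns the sharded buffers.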
