|
| 1 | +"""RecurrentStatePool -- recurrent + conv state buffer pool for linear recurrent layers. |
| 2 | +
|
| 3 | +Dual list containers: |
| 4 | + recurrent_buffers: list[jax.Array], length L, each [total_slots, H, D, D] |
| 5 | + conv_buffers: list[list[jax.Array]], outer L, inner 1, each [total_slots, proj_size, K-1] |
| 6 | +
|
| 7 | +Slot 0 reserved as dummy; valid slots start from 1 (aligned with sglang MambaPool). |
| 8 | +total_slots = ceil_to_dp(size + 1, dp_size) so the slot dim is evenly divisible |
| 9 | +by dp_size for P("data", ...) sharding. |
| 10 | +
|
| 11 | +Pure buffer pool: slot allocator state lives in HybridReqToTokenPool. |
| 12 | +Does NOT inherit from KVCache. |
| 13 | +""" |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import os |
| 18 | + |
| 19 | +import jax |
| 20 | +import jax.numpy as jnp |
| 21 | +from jax.sharding import Mesh, NamedSharding |
| 22 | +from jax.sharding import PartitionSpec as P |
| 23 | +from jax.tree_util import register_pytree_node_class |
| 24 | + |
| 25 | +_DTYPE_MAP = { |
| 26 | + "float32": jnp.float32, |
| 27 | + "bfloat16": jnp.bfloat16, |
| 28 | + "float16": jnp.float16, |
| 29 | +} |
| 30 | + |
| 31 | + |
| 32 | +def _resolve_dtype(env_var: str, default): |
| 33 | + name = os.environ.get(env_var) |
| 34 | + return _DTYPE_MAP[name] if name else default |
| 35 | + |
| 36 | + |
| 37 | +def _ceil_to(value: int, divisor: int) -> int: |
| 38 | + return (value + divisor - 1) // divisor * divisor |
| 39 | + |
| 40 | + |
| 41 | +@register_pytree_node_class |
| 42 | +class RecurrentStatePool: |
| 43 | + """Recurrent + conv state buffer pool (per-slot indexing, no slot allocator).""" |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, |
| 47 | + linear_recurrent_layer_ids: list[int], |
| 48 | + size: int, |
| 49 | + num_heads: int, |
| 50 | + head_dim: int, |
| 51 | + conv_kernel_size: int, |
| 52 | + mesh: Mesh, |
| 53 | + dp_size: int = 1, |
| 54 | + recurrent_partition_axis: str = "tensor", |
| 55 | + conv_partition_axis: str = "tensor", |
| 56 | + data_partition_axis: str = "data", |
| 57 | + temporal_dtype=None, |
| 58 | + conv_dtype=None, |
| 59 | + num_k_heads: int | None = None, |
| 60 | + head_k_dim: int | None = None, |
| 61 | + ): |
| 62 | + if temporal_dtype is None: |
| 63 | + temporal_dtype = _resolve_dtype("SGLANG_JAX_RECURRENT_STATE_DTYPE", jnp.float32) |
| 64 | + if conv_dtype is None: |
| 65 | + conv_dtype = _resolve_dtype("SGLANG_JAX_CONV_STATE_DTYPE", jnp.bfloat16) |
| 66 | + self.temporal_dtype = temporal_dtype |
| 67 | + self.conv_dtype = conv_dtype |
| 68 | + |
| 69 | + if num_k_heads is None: |
| 70 | + num_k_heads = num_heads |
| 71 | + if head_k_dim is None: |
| 72 | + head_k_dim = head_dim |
| 73 | + |
| 74 | + assert len(set(linear_recurrent_layer_ids)) == len(linear_recurrent_layer_ids), ( |
| 75 | + f"linear_recurrent_layer_ids must not contain duplicates, " |
| 76 | + f"got {linear_recurrent_layer_ids}" |
| 77 | + ) |
| 78 | + self.linear_recurrent_layer_ids: list[int] = list(linear_recurrent_layer_ids) |
| 79 | + self.layers_mapping: dict[int, int] = { |
| 80 | + layer_id: idx for idx, layer_id in enumerate(self.linear_recurrent_layer_ids) |
| 81 | + } |
| 82 | + self.num_linear_recurrent_layers: int = len(self.linear_recurrent_layer_ids) |
| 83 | + |
| 84 | + self.size = size |
| 85 | + self.dp_size = dp_size |
| 86 | + self.num_heads = num_heads |
| 87 | + self.head_dim = head_dim |
| 88 | + self.num_k_heads = num_k_heads |
| 89 | + self.head_k_dim = head_k_dim |
| 90 | + self.conv_kernel_size = conv_kernel_size |
| 91 | + |
| 92 | + proj_v = num_heads * head_dim |
| 93 | + proj_k = num_k_heads * head_k_dim |
| 94 | + self.proj_size = proj_v + 2 * proj_k |
| 95 | + |
| 96 | + # total_slots: size+1 (for dummy slot 0), ceil to dp_size |
| 97 | + self.total_slots = _ceil_to(size + 1, dp_size) |
| 98 | + |
| 99 | + assert size > 0, f"size must be > 0, got {size}" |
| 100 | + assert num_heads > 0 |
| 101 | + assert head_dim > 0 |
| 102 | + assert num_k_heads > 0 |
| 103 | + assert head_k_dim > 0 |
| 104 | + assert ( |
| 105 | + conv_kernel_size >= 2 |
| 106 | + ), f"conv_kernel_size must be >= 2 (got {conv_kernel_size}); K=1 produces empty conv buffers." |
| 107 | + assert self.proj_size > 0 |
| 108 | + |
| 109 | + self.mesh = mesh |
| 110 | + self.recurrent_partition_axis = recurrent_partition_axis |
| 111 | + self.conv_partition_axis = conv_partition_axis |
| 112 | + self.data_partition_axis = data_partition_axis |
| 113 | + |
| 114 | + recurrent_axis_size = mesh.shape[recurrent_partition_axis] |
| 115 | + conv_axis_size = mesh.shape[conv_partition_axis] |
| 116 | + assert num_heads % recurrent_axis_size == 0, ( |
| 117 | + f"num_heads {num_heads} must be divisible by " |
| 118 | + f"'{recurrent_partition_axis}' size {recurrent_axis_size}" |
| 119 | + ) |
| 120 | + assert num_k_heads % recurrent_axis_size == 0, ( |
| 121 | + f"num_k_heads {num_k_heads} must be divisible by " |
| 122 | + f"'{recurrent_partition_axis}' size {recurrent_axis_size}" |
| 123 | + ) |
| 124 | + assert self.proj_size % conv_axis_size == 0, ( |
| 125 | + f"proj_size {self.proj_size} must be divisible by " |
| 126 | + f"'{conv_partition_axis}' size {conv_axis_size}" |
| 127 | + ) |
| 128 | + |
| 129 | + self.recurrent_sharding = NamedSharding( |
| 130 | + mesh, P(data_partition_axis, recurrent_partition_axis, None, None) |
| 131 | + ) |
| 132 | + self.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None)) |
| 133 | + |
| 134 | + self.recurrent_buffers, self.conv_buffers = self._create_buffers() |
| 135 | + |
| 136 | + def _create_buffers(self) -> tuple[list, list]: |
| 137 | + recurrent_shape = (self.total_slots, self.num_heads, self.head_dim, self.head_dim) |
| 138 | + conv_shape = (self.total_slots, self.proj_size, self.conv_kernel_size - 1) |
| 139 | + temporal_dtype = self.temporal_dtype |
| 140 | + conv_dtype = self.conv_dtype |
| 141 | + |
| 142 | + with self.mesh: |
| 143 | + recurrent_buffers = [] |
| 144 | + for _ in range(self.num_linear_recurrent_layers): |
| 145 | + buf = jax.jit( |
| 146 | + lambda: jnp.zeros(shape=recurrent_shape, dtype=temporal_dtype), |
| 147 | + out_shardings=self.recurrent_sharding, |
| 148 | + )() |
| 149 | + recurrent_buffers.append(buf) |
| 150 | + |
| 151 | + conv_buffers = [] |
| 152 | + for _ in range(self.num_linear_recurrent_layers): |
| 153 | + inner = [] |
| 154 | + buf = jax.jit( |
| 155 | + lambda: jnp.zeros(shape=conv_shape, dtype=conv_dtype), |
| 156 | + out_shardings=self.conv_sharding, |
| 157 | + )() |
| 158 | + inner.append(buf) |
| 159 | + conv_buffers.append(inner) |
| 160 | + |
| 161 | + return recurrent_buffers, conv_buffers |
| 162 | + |
| 163 | + def clear_slot(self, idx_or_indices) -> None: |
| 164 | + """Zero the per-slot buffers for the given slot(s). Used for clear-on-alloc.""" |
| 165 | + indices = [idx_or_indices] if isinstance(idx_or_indices, int) else list(idx_or_indices) |
| 166 | + if not indices: |
| 167 | + return |
| 168 | + |
| 169 | + idx_arr = jnp.asarray(indices, dtype=jnp.int32) |
| 170 | + with jax.set_mesh(self.mesh): |
| 171 | + for layer in range(self.num_linear_recurrent_layers): |
| 172 | + self.recurrent_buffers[layer] = self.recurrent_buffers[layer].at[idx_arr].set(0) |
| 173 | + for inner in range(len(self.conv_buffers[layer])): |
| 174 | + self.conv_buffers[layer][inner] = ( |
| 175 | + self.conv_buffers[layer][inner].at[idx_arr].set(0) |
| 176 | + ) |
| 177 | + |
| 178 | + def get_linear_recurrent_layer_cache(self, layer_id: int): |
| 179 | + """Read the per-layer view, keyed by model-global layer_id. |
| 180 | +
|
| 181 | + Returns (recurrent_per_layer, conv_per_layer). |
| 182 | + """ |
| 183 | + if layer_id not in self.layers_mapping: |
| 184 | + raise ValueError( |
| 185 | + f"layer_id={layer_id} is not a registered linear recurrent layer. " |
| 186 | + f"Registered: {self.linear_recurrent_layer_ids}" |
| 187 | + ) |
| 188 | + idx = self.layers_mapping[layer_id] |
| 189 | + return self.recurrent_buffers[idx], self.conv_buffers[idx] |
| 190 | + |
| 191 | + def replace_buffer(self, buffers) -> None: |
| 192 | + """Update both buffer-list references after a JIT donate. |
| 193 | +
|
| 194 | + buffers: tuple[list[jax.Array], list[list[jax.Array]]] |
| 195 | + """ |
| 196 | + new_recurrent, new_conv = buffers |
| 197 | + |
| 198 | + assert len(new_recurrent) == self.num_linear_recurrent_layers |
| 199 | + assert len(new_conv) == self.num_linear_recurrent_layers |
| 200 | + for layer in range(self.num_linear_recurrent_layers): |
| 201 | + assert len(new_conv[layer]) == len(self.conv_buffers[layer]) |
| 202 | + |
| 203 | + tp_degenerate = self.mesh.shape.get("tensor", 1) == 1 |
| 204 | + for layer in range(self.num_linear_recurrent_layers): |
| 205 | + buf = new_recurrent[layer] |
| 206 | + if tp_degenerate and hasattr(self, "recurrent_sharding"): |
| 207 | + buf = jax.device_put(buf, self.recurrent_sharding) |
| 208 | + self.recurrent_buffers[layer] = buf |
| 209 | + |
| 210 | + for layer in range(self.num_linear_recurrent_layers): |
| 211 | + for i in range(len(new_conv[layer])): |
| 212 | + buf = new_conv[layer][i] |
| 213 | + if tp_degenerate and hasattr(self, "conv_sharding"): |
| 214 | + buf = jax.device_put(buf, self.conv_sharding) |
| 215 | + self.conv_buffers[layer][i] = buf |
| 216 | + |
| 217 | + def clear(self) -> None: |
| 218 | + """Full reset: zero out every layer's recurrent + conv buffer.""" |
| 219 | + for layer in range(self.num_linear_recurrent_layers): |
| 220 | + self.recurrent_buffers[layer] = jnp.zeros_like(self.recurrent_buffers[layer]) |
| 221 | + for inner in range(len(self.conv_buffers[layer])): |
| 222 | + self.conv_buffers[layer][inner] = jnp.zeros_like(self.conv_buffers[layer][inner]) |
| 223 | + |
| 224 | + # --- pytree --- |
| 225 | + def tree_flatten(self): |
| 226 | + children = (self.recurrent_buffers, self.conv_buffers) |
| 227 | + aux = ( |
| 228 | + tuple(self.linear_recurrent_layer_ids), |
| 229 | + self.size, |
| 230 | + self.dp_size, |
| 231 | + self.total_slots, |
| 232 | + self.num_heads, |
| 233 | + self.head_dim, |
| 234 | + self.num_k_heads, |
| 235 | + self.head_k_dim, |
| 236 | + self.conv_kernel_size, |
| 237 | + self.temporal_dtype, |
| 238 | + self.conv_dtype, |
| 239 | + self.mesh, |
| 240 | + self.recurrent_partition_axis, |
| 241 | + self.conv_partition_axis, |
| 242 | + self.data_partition_axis, |
| 243 | + self.recurrent_sharding, |
| 244 | + self.conv_sharding, |
| 245 | + ) |
| 246 | + return children, aux |
| 247 | + |
| 248 | + @classmethod |
| 249 | + def tree_unflatten(cls, aux_data, children): |
| 250 | + ( |
| 251 | + linear_recurrent_layer_ids_tup, |
| 252 | + size, |
| 253 | + dp_size, |
| 254 | + total_slots, |
| 255 | + num_heads, |
| 256 | + head_dim, |
| 257 | + num_k_heads, |
| 258 | + head_k_dim, |
| 259 | + conv_kernel_size, |
| 260 | + temporal_dtype, |
| 261 | + conv_dtype, |
| 262 | + mesh, |
| 263 | + recurrent_partition_axis, |
| 264 | + conv_partition_axis, |
| 265 | + data_partition_axis, |
| 266 | + recurrent_sharding, |
| 267 | + conv_sharding, |
| 268 | + ) = aux_data |
| 269 | + obj = cls.__new__(cls) |
| 270 | + obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup) |
| 271 | + obj.layers_mapping = { |
| 272 | + layer_id: idx for idx, layer_id in enumerate(obj.linear_recurrent_layer_ids) |
| 273 | + } |
| 274 | + obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids) |
| 275 | + obj.size = size |
| 276 | + obj.dp_size = dp_size |
| 277 | + obj.total_slots = total_slots |
| 278 | + obj.num_heads = num_heads |
| 279 | + obj.head_dim = head_dim |
| 280 | + obj.num_k_heads = num_k_heads |
| 281 | + obj.head_k_dim = head_k_dim |
| 282 | + obj.conv_kernel_size = conv_kernel_size |
| 283 | + obj.temporal_dtype = temporal_dtype |
| 284 | + obj.conv_dtype = conv_dtype |
| 285 | + proj_v = num_heads * head_dim |
| 286 | + proj_k = num_k_heads * head_k_dim |
| 287 | + obj.proj_size = proj_v + 2 * proj_k |
| 288 | + obj.mesh = mesh |
| 289 | + obj.recurrent_partition_axis = recurrent_partition_axis |
| 290 | + obj.conv_partition_axis = conv_partition_axis |
| 291 | + obj.data_partition_axis = data_partition_axis |
| 292 | + obj.recurrent_sharding = recurrent_sharding |
| 293 | + obj.conv_sharding = conv_sharding |
| 294 | + new_recurrent, new_conv = children |
| 295 | + obj.recurrent_buffers = list(new_recurrent) |
| 296 | + obj.conv_buffers = [list(inner) for inner in new_conv] |
| 297 | + return obj |
0 commit comments