Skip to content

Commit 7d50687

Browse files
lc5211 and The tunix Authors
authored and committed
[Tunix] Add support for aligning 1D KV biases in sglang_jax.
PiperOrigin-RevId: 876482529
1 parent df627a6 commit 7d50687

File tree

3 files changed

+80
-6
lines changed

3 files changed

+80
-6
lines changed

tests/generate/utils_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,29 @@ def test_transfer_state_directly_scanned_layers_casting(self):
11851185
atol=1e-2
11861186
)
11871187

1188+
def test_sglang_jax_1d_kv_bias_alignment(self):
  """Checks that a 1-D KV bias is replicated to fill the larger target shape
  when transferring state under the `sglang_jax` rollout engine.

  A 128-element `k_bias` must be tiled 8x to match the 1024-element
  destination parameter (num_kv_heads=1, head_dim=128).
  """
  key = "layers.0.attn.k_bias"
  bias = jnp.arange(128, dtype=jnp.float32)

  # Source carries the small per-head bias; destination expects the
  # replicated layout.
  source_state = MockState({key: MockParam(bias)})
  target_state = MockState(
      {key: MockParam(jnp.zeros((1024,), dtype=jnp.float32))}
  )

  transferred = utils.transfer_state_with_mappings(
      source_state,
      target_state,
      {key: (key, None)},
      rollout_engine="sglang_jax",
      num_kv_heads=1,
      head_dim=128,
  )

  aligned = transferred.params[key]
  self.assertEqual(aligned.shape, (1024,))
  # The bias should appear as 8 back-to-back copies of the original vector.
  self.assertTrue(jnp.allclose(aligned, jnp.tile(bias, 8)))
1210+
11881211

11891212
class ResolveParallelismSizesTest(parameterized.TestCase):
11901213

tunix/generate/sglang_jax_sampler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Sampler for sglang-jax-style autoregressive decoding using JAX and NNX models."""
1616

17+
import asyncio
1718
import dataclasses
1819
import math
1920
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -150,6 +151,16 @@ def update_params(
150151
transpose_keys=self.to_hf_transpose_keys,
151152
reshard_fn=reshard.reshard_pytree,
152153
rollout_engine="sglang_jax",
154+
num_kv_heads=(
155+
None
156+
if not self._model_runner
157+
else self._model_runner.model_config.get_total_num_kv_heads()
158+
),
159+
head_dim=(
160+
None
161+
if not self._model_runner
162+
else self._model_runner.model_config.head_dim
163+
),
153164
)
154165
new_model_state_leaves, _ = jax.tree_util.tree_flatten(new_state)
155166
self._model_runner.model_state_leaves = new_model_state_leaves
@@ -394,9 +405,6 @@ def wrap_generate():
394405
return future.result()
395406

396407

397-
import asyncio
398-
399-
400408
def get_or_create_event_loop():
401409
try:
402410
loop = asyncio.get_running_loop()

tunix/generate/utils.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,31 @@ def _apply_transpose(
524524

525525

526526
def _align_shape(
527-
val: jnp.ndarray, tgt_shape: Tuple[int, ...], src_key: str
527+
val: jnp.ndarray,
528+
tgt_shape: Tuple[int, ...],
529+
src_key: str,
530+
rollout_engine: Optional[str] = None,
531+
**kwargs,
528532
) -> jnp.ndarray:
529533
"""Align source value shape to target shape through padding or repeating.
530534
535+
This function attempts to align the shape of a source JAX array (`val`) to a
536+
target shape (`tgt_shape`). It supports alignment by:
537+
1. Reshaping: If the product of dimensions matches, especially for attention
538+
biases and projections.
539+
2. Padding/Repeating: For attention-related weights, it can pad the head
540+
dimension or repeat along the number of heads dimension.
541+
3. Special Handling: Includes specific logic for 1-D KV biases in
542+
'sglang_jax' rollout.
543+
531544
Args:
532545
val: Source value.
533546
tgt_shape: Target shape.
534547
src_key: Source key for error messages.
548+
rollout_engine: Optional string indicating the rollout engine, used for
549+
special-casing certain alignments (e.g., 'sglang_jax').
550+
**kwargs: Additional keyword arguments, potentially containing metadata
551+
like 'num_kv_heads' and 'head_dim' for specific alignment logic.
535552
536553
Returns:
537554
Shape-aligned value.
@@ -596,6 +613,27 @@ def _align_shape(
596613
raise ShapeMismatchError(
597614
f'Rank mismatch for {src_key}: {val.shape} vs {tgt_shape}'
598615
)
616+
elif rollout_engine == 'sglang_jax' and re.compile(
617+
r'layers\..*\.attn\.(k|v)_bias'
618+
).match(src_key):
619+
logging.debug(
620+
'Handling 1-D KV bias for %s in SGLangJAX rollout.', src_key
621+
)
622+
assert tgt_shape[0] > val.shape[0] and tgt_shape[0] % val.shape[0] == 0, (
623+
f'Unexpected attention bias shape: {val.shape} and target shape:'
624+
f' {tgt_shape}'
625+
)
626+
repeat_factor = tgt_shape[0] // val.shape[0]
627+
logging.debug(
628+
'Replicating 1-D KV bias on %s: %s -> %s (repeat x%d per head)',
629+
src_key,
630+
val.shape,
631+
tgt_shape,
632+
repeat_factor,
633+
)
634+
val_2d = jnp.reshape(val, (kwargs['num_kv_heads'], kwargs['head_dim']))
635+
val_2d = jnp.repeat(val_2d, repeat_factor, axis=0)
636+
return jnp.reshape(val_2d, tgt_shape)
599637

600638
attention_patterns = [
601639
r'.*(q|k|v|o)_proj.*',
@@ -680,6 +718,7 @@ def transfer_state_with_mappings(
680718
transpose_keys=None,
681719
reshard_fn=None,
682720
rollout_engine=None,
721+
**kwargs,
683722
):
684723
"""Transfer state using mappings, with optional transpose and shard logic.
685724
@@ -695,6 +734,8 @@ def transfer_state_with_mappings(
695734
transpose_keys: A dictionary defining which keys to transpose and the
696735
corresponding axes to transpose.
697736
reshard_fn: A function to shard the value.
737+
rollout_engine: The name of the rollout engine being used.
738+
**kwargs: Additional keyword arguments.
698739
699740
Returns:
700741
The target state with the transferred values.
@@ -722,7 +763,7 @@ def transfer_state_with_mappings(
722763
unscanned_src_to_tgt_flat = _unroll_scanned_layers(src_state, src_to_tgt_map)
723764

724765
# Transfer values with transformations
725-
for (flat_src_key, tgt_key), (
766+
for (flat_src_key, _), (
726767
val,
727768
tgt_param,
728769
) in unscanned_src_to_tgt_flat.items():
@@ -734,7 +775,9 @@ def transfer_state_with_mappings(
734775
val = key_mapping_hook_fns[flat_src_key](val)
735776

736777
# Align shapes (padding/repeating as needed)
737-
val = _align_shape(val, tgt_param.value.shape, flat_src_key)
778+
val = _align_shape(
779+
val, tgt_param.value.shape, flat_src_key, rollout_engine, **kwargs
780+
)
738781

739782
# Cast to target dtype
740783
val = _apply_dtype_cast(val, tgt_param.value.dtype, flat_src_key)

0 commit comments

Comments
 (0)