Commit 2aa8b8a
fix
1 parent 2537c72 commit 2aa8b8a
3 files changed: 86 additions & 58 deletions

File tree:
python/sgl_jax/srt/layers/attention/fla/group_rmsnorm.py
python/sgl_jax/srt/layers/attention/linear/lightning_backend.py
python/sgl_jax/test/layers/test_group_rmsnorm.py

python/sgl_jax/srt/layers/attention/fla/group_rmsnorm.py

Lines changed: 8 additions & 0 deletions

@@ -30,6 +30,14 @@ def __init__(
     ):
         if hidden_size % num_groups != 0:
             raise ValueError("hidden_size must be divisible by num_groups")
+        if mesh is not None:
+            tp_size = mesh.shape.get("tensor", 1)
+            if tp_size < num_groups:
+                raise ValueError(
+                    "GroupRMSNorm requires tensor parallel size to be at least "
+                    f"num_groups to keep each RMS group intact, got tensor "
+                    f"parallel size={tp_size}, num_groups={num_groups}."
+                )

         self.hidden_size = hidden_size
         self.num_groups = num_groups
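
A minimal sketch of how the new guard behaves, assuming the constructor keywords shown in the updated test below (kernel_axes, mesh) and illustrative values for hidden_size and num_groups; with a mesh whose "tensor" axis is smaller than num_groups, construction is expected to raise:

import jax
import numpy as np
from jax.sharding import Mesh

from sgl_jax.srt.layers.attention.fla.group_rmsnorm import GroupRMSNorm

# A 1x1 mesh: mesh.shape.get("tensor", 1) evaluates to 1, which is below num_groups=4.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "tensor"))

try:
    GroupRMSNorm(128, num_groups=4, epsilon=1e-6, kernel_axes=("tensor",), mesh=mesh)
except ValueError as err:
    print(err)  # "...got tensor parallel size=1, num_groups=4."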

python/sgl_jax/srt/layers/attention/linear/lightning_backend.py

Lines changed: 22 additions & 51 deletions

@@ -16,7 +16,6 @@

 import logging
 import math
-import os
 from typing import TYPE_CHECKING

 import jax
@@ -32,19 +31,13 @@
 logger = logging.getLogger(__name__)

 try:
-    from sgl_jax.srt.kernels.simple_gla.native import (
-        naive_gla_decode,
-        naive_gla_prefill,
-    )
     from sgl_jax.srt.kernels.simple_gla.simple_gla import (
         fused_recurrent_simple_gla,
         simple_gla_fwd,
     )
 except ModuleNotFoundError:
     simple_gla_fwd = None
     fused_recurrent_simple_gla = None
-    naive_gla_decode = None
-    naive_gla_prefill = None

 if TYPE_CHECKING:
     from sgl_jax.srt.layers.radix_lightning_attention import RadixLightningAttention
@@ -107,7 +100,6 @@ def __init__(
         """
         super().__init__(mesh=mesh)
         self.chunk_size = chunk_size
-        self.use_native_gla = os.environ.get("SGLANG_JAX_GLA_BACKEND", "").lower() == "native"
         if (
             linear_recurrent_layer_ids is not None
             and num_hidden_layers is not None
@@ -212,7 +204,7 @@ def _forward_decode(
         slope: jnp.ndarray,
     ) -> tuple[jax.Array, jax.Array]:
         """Decode forward using shard_map."""
-        if fused_recurrent_simple_gla is None or naive_gla_decode is None:
+        if fused_recurrent_simple_gla is None:
             raise ImportError("simple_gla kernel is required for GLA decode")

         ssm_states = ssm_states.astype(jnp.float32)
@@ -221,25 +213,15 @@ def _decode_fn(q_local, k_local, v_local, gamma, h0):
             q_d = q_local[:, None, :, :]
             k_d = k_local[:, None, :, :]
             v_d = v_local[:, None, :, :]
-            if self.use_native_gla:
-                output_d, new_state = naive_gla_decode(
-                    q_d,
-                    k_d,
-                    v_d,
-                    g_gamma=gamma,
-                    h0=h0,
-                    scale=None,
-                )
-            else:
-                output_d, new_state = fused_recurrent_simple_gla(
-                    q_d,
-                    k_d,
-                    v_d,
-                    g_gamma=gamma,
-                    initial_state=h0,
-                    output_final_state=True,
-                    scale=None,
-                )
+            output_d, new_state = fused_recurrent_simple_gla(
+                q_d,
+                k_d,
+                v_d,
+                g_gamma=gamma,
+                initial_state=h0,
+                output_final_state=True,
+                scale=None,
+            )
             return output_d[:, 0, :, :], new_state

         output, new_state = jax.shard_map(
@@ -270,36 +252,25 @@ def _forward_extend(
         slope: jnp.ndarray,
     ) -> tuple[jax.Array, jax.Array]:
         """Extend forward using shard_map."""
-        if simple_gla_fwd is None or naive_gla_prefill is None:
+        if simple_gla_fwd is None:
             raise ImportError("simple_gla kernel is required for GLA prefill")

         cu_seqlens = self.forward_metadata.cu_q_lens
         ssm_states = ssm_states.astype(jnp.float32)
         chunk_size = self.chunk_size

         def _prefill_fn(q_local, k_local, v_local, gamma, h0, cu_seqlens_p):
-            if self.use_native_gla:
-                output, ht = naive_gla_prefill(
-                    q_local[None],
-                    k_local[None],
-                    v_local[None],
-                    g_gamma=gamma,
-                    h0=h0,
-                    cu_seqlens=cu_seqlens_p,
-                    scale=None,
-                )
-            else:
-                output, ht = simple_gla_fwd(
-                    q_local[None],
-                    k_local[None],
-                    v_local[None],
-                    g_gamma=gamma,
-                    h0=h0,
-                    cu_seqlens_dev=cu_seqlens_p,
-                    scale=None,
-                    use_ht=True,
-                    chunk_size=chunk_size,
-                )
+            output, ht = simple_gla_fwd(
+                q_local[None],
+                k_local[None],
+                v_local[None],
+                g_gamma=gamma,
+                h0=h0,
+                cu_seqlens_dev=cu_seqlens_p,
+                scale=None,
+                use_ht=True,
+                chunk_size=chunk_size,
+            )
             return output[0], ht

         output, new_state = jax.shard_map(
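
For readers unfamiliar with the pattern this backend keeps, here is a self-contained sketch of wrapping a per-shard decode step with jax.shard_map, sharding the head dimension over the "tensor" mesh axis. The dummy decode_step and its gated recurrence (h_t = gamma * h_{t-1} + k^T v) only stand in for fused_recurrent_simple_gla; shapes, axis names, and mesh layout are assumptions, not the backend's actual code.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# 1x1 mesh so the sketch runs on a single device; the real backend shards "tensor".
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "tensor"))

def decode_step(q, k, v, gamma, h0):
    # One gated recurrent step per head: h_t = gamma * h_{t-1} + k^T v, o_t = q . h_t.
    h = gamma[None, :, None, None] * h0 + jnp.einsum("bhk,bhv->bhkv", k, v)
    o = jnp.einsum("bhk,bhkv->bhv", q, h)
    return o, h

batch, heads, head_dim = 2, 4, 8
q = k = v = jnp.ones((batch, heads, head_dim))
gamma = jnp.full((heads,), 0.9)
h0 = jnp.zeros((batch, heads, head_dim, head_dim))

sharded_decode = jax.shard_map(
    decode_step,
    mesh=mesh,
    in_specs=(P(None, "tensor"), P(None, "tensor"), P(None, "tensor"), P("tensor"), P(None, "tensor")),
    out_specs=(P(None, "tensor"), P(None, "tensor")),
)
output, new_state = sharded_decode(q, k, v, gamma, h0)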

python/sgl_jax/test/layers/test_group_rmsnorm.py

Lines changed: 56 additions & 7 deletions

@@ -1,5 +1,9 @@
+import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
+from jax.sharding import AxisType, Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P

 from sgl_jax.srt.layers.attention.fla.group_rmsnorm import GroupRMSNorm

@@ -39,17 +43,42 @@ def _make_weight(rng, hidden_size=HIDDEN_SIZE):
     return rng.standard_normal(hidden_size).astype(np.float32)


+def _make_mesh(num_groups=NUM_GROUPS):
+    devices = np.array(jax.devices())
+    if devices.size < num_groups:
+        pytest.skip(
+            f"GroupRMSNorm sharded test requires at least {num_groups} devices, got {devices.size}"
+        )
+    return Mesh(
+        devices[:num_groups].reshape(1, num_groups),
+        axis_names=("data", "tensor"),
+        axis_types=(AxisType.Explicit, AxisType.Explicit),
+    )
+
+
 def _make_jax_model(hidden_size=HIDDEN_SIZE, num_groups=NUM_GROUPS, weight=None):
     """Create a JAX GroupRMSNorm model, optionally with custom weight."""
-    model = GroupRMSNorm(hidden_size, num_groups=num_groups, epsilon=EPSILON)
+    mesh = _make_mesh(num_groups)
+    with jax.set_mesh(mesh):
+        model = GroupRMSNorm(
+            hidden_size,
+            num_groups=num_groups,
+            epsilon=EPSILON,
+            kernel_axes=("tensor",),
+            mesh=mesh,
+        )
     if weight is not None:
-        model.weight[...] = jnp.array(weight)
+        model.weight[...] = jax.device_put(
+            jnp.array(weight),
+            NamedSharding(mesh, P("tensor")),
+        )
     return model


 def _run_jax(model, input_np, dtype=jnp.float32):
     """Run JAX model and return numpy array."""
-    return np.array(model(jnp.array(input_np, dtype=dtype)))
+    with jax.set_mesh(model.mesh):
+        return np.array(model(jnp.array(input_np, dtype=dtype)))


 class TestGroupRMSNorm:
@@ -58,18 +87,18 @@ class TestGroupRMSNorm:
     def test_output_shape_matches_input(self):
         """Output shape must match input shape."""
         rng = np.random.default_rng(SEED)
-        input_data = jnp.array(_make_input(rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)))
+        input_data = _make_input(rng, (BATCH_SIZE * SEQ_LEN, HIDDEN_SIZE))

         model = _make_jax_model()
-        output = model(input_data)
+        output = _run_jax(model, input_data)

         assert output.shape == input_data.shape

     def test_groups_are_independent(self):
         """Modifying one group must not affect other groups' outputs."""
         rng = np.random.default_rng(SEED)

-        input_original = _make_input(rng, (1, 1, HIDDEN_SIZE))
+        input_original = _make_input(rng, (1, HIDDEN_SIZE))
         input_modified = input_original.copy()
         input_modified[..., :GROUP_SIZE] = _make_input(rng, (GROUP_SIZE,))  # perturb group 0 only

@@ -93,11 +122,31 @@ def test_groups_are_independent(self):
     def test_weight_participates_in_computation(self):
         """Weight parameter must participate in computation correctly."""
         rng = np.random.default_rng(SEED)
-        input_data = _make_input(rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE))
+        input_data = _make_input(rng, (BATCH_SIZE * SEQ_LEN, HIDDEN_SIZE))
         weight = _make_weight(rng)

         model = _make_jax_model(weight=weight)
         jax_output = _run_jax(model, input_data)
         expected = _numpy_group_rmsnorm_fp64(input_data, weight, NUM_GROUPS, EPSILON)

         np.testing.assert_allclose(jax_output, expected, rtol=FP32_RTOL, atol=FP32_ATOL)
+
+    def test_rejects_tp_smaller_than_num_groups(self):
+        """Tensor parallelism must be at least the number of RMS groups."""
+        mesh = Mesh(
+            np.array(jax.devices()[:1]).reshape(1, 1),
+            axis_names=("data", "tensor"),
+            axis_types=(AxisType.Explicit, AxisType.Explicit),
+        )
+
+        with pytest.raises(
+            ValueError,
+            match="tensor parallel size.*num_groups",
+        ):
+            GroupRMSNorm(
+                HIDDEN_SIZE,
+                num_groups=NUM_GROUPS,
+                epsilon=EPSILON,
+                kernel_axes=("tensor",),
+                mesh=mesh,
+            )
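
The fp64 reference helper _numpy_group_rmsnorm_fp64 that the weight test compares against is not part of this diff; a plausible reference implementation, written here only as an assumption about its behavior, would be:

import numpy as np

def _numpy_group_rmsnorm_fp64(x, weight, num_groups, epsilon):
    # Normalize each group of the last dimension by its own RMS, then apply the
    # element-wise weight; all math in float64 to serve as a high-precision reference.
    x64 = np.asarray(x, dtype=np.float64)
    w64 = np.asarray(weight, dtype=np.float64)
    group_size = x64.shape[-1] // num_groups
    grouped = x64.reshape(*x64.shape[:-1], num_groups, group_size)
    rms = np.sqrt(np.mean(grouped * grouped, axis=-1, keepdims=True) + epsilon)
    normed = (grouped / rms).reshape(x64.shape)
    return normed * w64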
