1+ import jax
12import jax .numpy as jnp
23import numpy as np
4+ import pytest
5+ from jax .sharding import AxisType , Mesh , NamedSharding
6+ from jax .sharding import PartitionSpec as P
37
48from sgl_jax .srt .layers .attention .fla .group_rmsnorm import GroupRMSNorm
59
@@ -39,17 +43,42 @@ def _make_weight(rng, hidden_size=HIDDEN_SIZE):
3943 return rng .standard_normal (hidden_size ).astype (np .float32 )
4044
4145
46+ def _make_mesh (num_groups = NUM_GROUPS ):
47+ devices = np .array (jax .devices ())
48+ if devices .size < num_groups :
49+ pytest .skip (
50+ f"GroupRMSNorm sharded test requires at least { num_groups } devices, got { devices .size } "
51+ )
52+ return Mesh (
53+ devices [:num_groups ].reshape (1 , num_groups ),
54+ axis_names = ("data" , "tensor" ),
55+ axis_types = (AxisType .Explicit , AxisType .Explicit ),
56+ )
57+
58+
def _make_jax_model(hidden_size=HIDDEN_SIZE, num_groups=NUM_GROUPS, weight=None):
    """Create a JAX GroupRMSNorm model, optionally with custom weight.

    The module is constructed under a (1, num_groups) explicit mesh so its
    kernel is sharded along the "tensor" axis; a caller-provided weight is
    placed with the matching ``P("tensor")`` sharding before assignment.

    Args:
        hidden_size: Total hidden dimension of the norm.
        num_groups: Number of RMS groups (== tensor-parallel size here).
        weight: Optional numpy weight of shape (hidden_size,) to install.

    Returns:
        A ``GroupRMSNorm`` instance whose ``mesh`` attribute is the mesh it
        was built under.
    """
    mesh = _make_mesh(num_groups)
    with jax.set_mesh(mesh):
        model = GroupRMSNorm(
            hidden_size,
            num_groups=num_groups,
            epsilon=EPSILON,
            kernel_axes=("tensor",),
            mesh=mesh,
        )
    if weight is not None:
        # Match the model's kernel sharding so the in-place update is legal
        # under the explicit mesh.
        model.weight[...] = jax.device_put(
            jnp.array(weight),
            NamedSharding(mesh, P("tensor")),
        )
    return model


5078def _run_jax (model , input_np , dtype = jnp .float32 ):
5179 """Run JAX model and return numpy array."""
52- return np .array (model (jnp .array (input_np , dtype = dtype )))
80+ with jax .set_mesh (model .mesh ):
81+ return np .array (model (jnp .array (input_np , dtype = dtype )))
5382
5483
5584class TestGroupRMSNorm :
@@ -58,18 +87,18 @@ class TestGroupRMSNorm:
5887 def test_output_shape_matches_input (self ):
5988 """Output shape must match input shape."""
6089 rng = np .random .default_rng (SEED )
61- input_data = jnp . array ( _make_input (rng , (BATCH_SIZE , SEQ_LEN , HIDDEN_SIZE ) ))
90+ input_data = _make_input (rng , (BATCH_SIZE * SEQ_LEN , HIDDEN_SIZE ))
6291
6392 model = _make_jax_model ()
64- output = model ( input_data )
93+ output = _run_jax ( model , input_data )
6594
6695 assert output .shape == input_data .shape
6796
6897 def test_groups_are_independent (self ):
6998 """Modifying one group must not affect other groups' outputs."""
7099 rng = np .random .default_rng (SEED )
71100
72- input_original = _make_input (rng , (1 , 1 , HIDDEN_SIZE ))
101+ input_original = _make_input (rng , (1 , HIDDEN_SIZE ))
73102 input_modified = input_original .copy ()
74103 input_modified [..., :GROUP_SIZE ] = _make_input (rng , (GROUP_SIZE ,)) # perturb group 0 only
75104
@@ -93,11 +122,31 @@ def test_groups_are_independent(self):
93122 def test_weight_participates_in_computation (self ):
94123 """Weight parameter must participate in computation correctly."""
95124 rng = np .random .default_rng (SEED )
96- input_data = _make_input (rng , (BATCH_SIZE , SEQ_LEN , HIDDEN_SIZE ))
125+ input_data = _make_input (rng , (BATCH_SIZE * SEQ_LEN , HIDDEN_SIZE ))
97126 weight = _make_weight (rng )
98127
99128 model = _make_jax_model (weight = weight )
100129 jax_output = _run_jax (model , input_data )
101130 expected = _numpy_group_rmsnorm_fp64 (input_data , weight , NUM_GROUPS , EPSILON )
102131
103132 np .testing .assert_allclose (jax_output , expected , rtol = FP32_RTOL , atol = FP32_ATOL )
133+
134+ def test_rejects_tp_smaller_than_num_groups (self ):
135+ """Tensor parallelism must be at least the number of RMS groups."""
136+ mesh = Mesh (
137+ np .array (jax .devices ()[:1 ]).reshape (1 , 1 ),
138+ axis_names = ("data" , "tensor" ),
139+ axis_types = (AxisType .Explicit , AxisType .Explicit ),
140+ )
141+
142+ with pytest .raises (
143+ ValueError ,
144+ match = "tensor parallel size.*num_groups" ,
145+ ):
146+ GroupRMSNorm (
147+ HIDDEN_SIZE ,
148+ num_groups = NUM_GROUPS ,
149+ epsilon = EPSILON ,
150+ kernel_axes = ("tensor" ,),
151+ mesh = mesh ,
152+ )
0 commit comments