
Commit e32453d

aolemila and claude committed
test(kda): hoist KDAAttnBackendForTest shim to test_utils.py
Both KDA test files were carrying an identical copy of the `pool=` → `recurrent_state_pool=` translation shim added in 466afff. Move it to test_utils.py and import it from both, dropping the local underscore prefix since it's now a shared helper.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 466afff commit e32453d

3 files changed: 29 additions & 53 deletions
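The shim exists because of a kwarg-name mismatch between two layers. Below is a minimal sketch of the routing described in the shim's docstring; only the names HybridLinearAttnBackend, KDAAttnBackend, `pool`, and `recurrent_state_pool` come from this commit, and the signatures are illustrative, not the real API:

class KDAAttnBackendSketch:
    """Stand-in for KDAAttnBackend: expects the state pool as `recurrent_state_pool=`."""

    def __call__(self, *args, recurrent_state_pool=None, **kwargs):
        ...  # linear-attention forward over the recurrent state


class HybridLinearAttnBackendSketch:
    """Stand-in for HybridLinearAttnBackend, the production wrapper."""

    def __init__(self, linear_backend):
        self.linear_backend = linear_backend

    def __call__(self, *args, pool=None, **kwargs):
        # RadixLinearAttention's call convention passes `pool=`; rename it
        # to the kwarg the linear sub-backend expects.
        return self.linear_backend(*args, recurrent_state_pool=pool, **kwargs)

Tests that install the raw KDA backend directly skip the hybrid wrapper, which is the gap the hoisted shim fills.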


python/sgl_jax/test/test_kda_attention.py

Lines changed: 2 additions & 26 deletions
@@ -15,36 +15,12 @@
 from sgl_jax.srt.mem_cache.recurrent_state_pool import RecurrentStatePool
 from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sgl_jax.srt.utils.mesh_utils import create_device_mesh
+from sgl_jax.test.test_utils import KDAAttnBackendForTest

 mesh = create_device_mesh(ici_parallelism=[1, -1], dcn_parallelism=[1, 1])
 jax.sharding.set_mesh(mesh)


-class _KDAAttnBackendForTest:
-    """Test wrapper that translates `pool=` kwarg to `recurrent_state_pool=`.
-
-    Production routes through HybridLinearAttnBackend, which accepts `pool=`
-    (RadixLinearAttention's call convention) and forwards it to KDA as
-    `recurrent_state_pool=`. These tests assign the raw KDA backend as
-    `forward_batch.attn_backend`, bypassing that wrapper, so we replicate the
-    same translation here.
-    """
-
-    def __init__(self, backend):
-        object.__setattr__(self, "_backend", backend)
-
-    def __call__(self, *args, **kwargs):
-        if "pool" in kwargs:
-            kwargs["recurrent_state_pool"] = kwargs.pop("pool")
-        return self._backend(*args, **kwargs)
-
-    def __getattr__(self, name):
-        return getattr(self._backend, name)
-
-    def __setattr__(self, name, value):
-        setattr(self._backend, name, value)
-
-
 def _scaled_randn(rng: np.random.Generator, shape, scale: float = 0.1) -> np.ndarray:
     # scale=0.1 is a test-only hack: it shrinks the recurrent state so bf16 noise
     # in the delta-rule update fits the global atol=1e-2 (shared with flashattn).
@@ -348,7 +324,7 @@ def conv_weight():
     extend_prefix_lens = np.zeros(batch_size, dtype=np.int32) if mode == "prefill" else None
     has_initial_state_np = np.asarray(has_initial_state_per_req, dtype=np.bool_)

-    backend = _KDAAttnBackendForTest(KDAAttnBackend(mesh=test_mesh))
+    backend = KDAAttnBackendForTest(KDAAttnBackend(mesh=test_mesh))

     mwb = ModelWorkerBatch(
         bid=1,
python/sgl_jax/test/test_kda_attention_dp.py

Lines changed: 2 additions & 27 deletions
@@ -16,7 +16,7 @@
 from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sgl_jax.srt.utils.common_utils import pad_to_bucket
 from sgl_jax.srt.utils.mesh_utils import create_device_mesh
-from sgl_jax.test.test_utils import CustomTestCase
+from sgl_jax.test.test_utils import CustomTestCase, KDAAttnBackendForTest


 def _scaled_randn(rng: np.random.Generator, shape, scale: float = 0.1) -> np.ndarray:
@@ -27,31 +27,6 @@ def _scaled_randn(rng: np.random.Generator, shape, scale: float = 0.1) -> np.ndarray:
     return rng.standard_normal(shape).astype(np.float32) * scale


-class _KDAAttnBackendForTest:
-    """Test wrapper that translates `pool=` kwarg to `recurrent_state_pool=`.
-
-    Production routes through HybridLinearAttnBackend, which accepts `pool=`
-    (RadixLinearAttention's call convention) and forwards it to KDA as
-    `recurrent_state_pool=`. These tests assign the raw KDA backend as
-    `forward_batch.attn_backend`, bypassing that wrapper, so we replicate the
-    same translation here.
-    """
-
-    def __init__(self, backend):
-        object.__setattr__(self, "_backend", backend)
-
-    def __call__(self, *args, **kwargs):
-        if "pool" in kwargs:
-            kwargs["recurrent_state_pool"] = kwargs.pop("pool")
-        return self._backend(*args, **kwargs)
-
-    def __getattr__(self, name):
-        return getattr(self._backend, name)
-
-    def __setattr__(self, name, value):
-        setattr(self._backend, name, value)
-
-
 # Reference baselines duplicated from test_kda_attention.py — keep in sync.


@@ -420,7 +395,7 @@ def conv_weight():
     req_pool_indices_cpu = np.arange(total_bs, dtype=np.int32)

     real_bs_per_dp = [len(lens_per_rank.get(r, [])) for r in range(dp_size)]
-    backend = _KDAAttnBackendForTest(KDAAttnBackend(mesh=mesh))
+    backend = KDAAttnBackendForTest(KDAAttnBackend(mesh=mesh))

     mwb = ModelWorkerBatch(
         bid=1,

python/sgl_jax/test/test_utils.py

Lines changed: 25 additions & 0 deletions
@@ -788,3 +788,28 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
         rouge_l_scores.append(fmeasure)

     return rouge_l_scores
+
+
+class KDAAttnBackendForTest:
+    """Test wrapper that translates `pool=` kwarg to `recurrent_state_pool=`.
+
+    Production routes through HybridLinearAttnBackend, which accepts `pool=`
+    (RadixLinearAttention's call convention) and forwards it to the linear
+    sub-backend as `recurrent_state_pool=`. Tests that assign a raw linear
+    backend (e.g. KDAAttnBackend) as `forward_batch.attn_backend` bypass the
+    wrapper, so this shim replicates the same translation.
+    """
+
+    def __init__(self, backend):
+        object.__setattr__(self, "_backend", backend)
+
+    def __call__(self, *args, **kwargs):
+        if "pool" in kwargs:
+            kwargs["recurrent_state_pool"] = kwargs.pop("pool")
+        return self._backend(*args, **kwargs)
+
+    def __getattr__(self, name):
+        return getattr(self._backend, name)
+
+    def __setattr__(self, name, value):
+        setattr(self._backend, name, value)
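One subtlety in the hoisted shim: because `__setattr__` forwards every attribute write to the wrapped backend, `__init__` must plant `_backend` via `object.__setattr__`. A plain `self._backend = backend` would hit the override before `_backend` exists, the lookup of `self._backend` would fall through to `__getattr__`, and that reads `self._backend` again, recursing forever. A minimal reproduction with a hypothetical class (standard Python attribute semantics, unrelated to any sgl_jax API):

class BrokenProxy:
    def __init__(self, backend):
        self._backend = backend  # routed through the __setattr__ override below

    def __getattr__(self, name):
        # Fires for any attribute missing from the instance dict,
        # including `_backend` itself before __init__ has stored it.
        return getattr(self._backend, name)

    def __setattr__(self, name, value):
        # `self._backend` is not set yet, so this lookup triggers __getattr__,
        # which looks up `self._backend` again: RecursionError.
        setattr(self._backend, name, value)


BrokenProxy(object())  # raises RecursionError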
