
Commit b6d7606

fix: wan2.1

1 parent 21a817c · commit b6d7606

7 files changed · 102 additions & 35 deletions

python/sgl_jax/srt/kernels/update_kv_cache/update_kv_cache.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ def kv_cache_update_impl(
 @partial(
     jax.jit,
     static_argnames=["page_size", "num_slices_per_block", "kv_partition_axis"],
+    donate_argnums=(2,),
 )
 def kv_cache_update(
     new_kv: jax.Array,  # [total_num_token, num_kv_heads, head_dim]
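
Note: a minimal sketch of what donate_argnums does here, using toy names (toy_cache_update, cache, new_vals, loc) rather than the repository's kernel. Donating the third positional argument lets XLA reuse that input buffer for the output, so the KV-cache update need not allocate a second full-size cache.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(2,))
def toy_cache_update(new_vals, loc, cache):
    # Scatter the new token values into the donated cache buffer.
    return cache.at[loc].set(new_vals)


cache = jnp.zeros((8, 4))
new_vals = jnp.ones((2, 4))
loc = jnp.array([1, 5])
cache = toy_cache_update(new_vals, loc, cache)
# After donation the original `cache` buffer may be invalidated; only the
# returned array should be used (which is why the tests below now pass
# k_cache.copy() / v_cache.copy()).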

python/sgl_jax/srt/layers/attention/native_backend.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def __call__(
             P("data"),  # extend_prefix_lens
             P("data"),  # extend_seq_lens
         )
-        out_specs = P("data", "tensor", None)  # attn_output
+        out_specs = P("data", "tensor")  # attn_output: [num_tokens, hidden_size]
 
         attn_output = jax.shard_map(
             lambda q_local, k_local, v_local, seq_lens_local, loc_local, prefix_lens_local, extend_lens_local: forward_attention(
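
Note: the out_specs fix matters because a PartitionSpec may not have more entries than the output has dimensions; the attention output here is 2-D ([num_tokens, hidden_size]), so the 3-entry spec was invalid. A minimal sketch, assuming an illustrative (1, 1) mesh and a stand-in per-shard function:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "tensor"))

def per_shard(x):
    return x * 2.0  # stand-in for the local attention computation

x = jnp.ones((4, 8))  # [num_tokens, hidden_size]
y = jax.shard_map(
    per_shard,
    mesh=mesh,
    in_specs=(P("data", "tensor"),),
    out_specs=P("data", "tensor"),  # at most one entry per output dimension
)(x)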

python/sgl_jax/srt/mem_cache/memory_pool.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 import abc
 import logging
 import time
+from functools import partial
 
 import jax
 import jax.numpy as jnp
@@ -784,6 +785,7 @@ def update_fused_kv_cache_vectorized(
     by grouping contiguous tokens into page-sized chunks for efficient updates.
     """
 
+    @partial(jax.jit, donate_argnums=(2,))
     @jax.shard_map(
         in_specs=(
             # fused_kv: sharded by data (tokens) and tensor (heads)
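
Note: a minimal sketch (illustrative names and a (1, 1) mesh) of the decorator stacking this hunk introduces: jax.jit with donate_argnums wraps a shard_map-decorated body, so the fused KV cache argument can be overwritten in place while the update still runs per shard across the "data" and "tensor" axes.

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "tensor"))


@partial(jax.jit, donate_argnums=(2,))
@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P("data", "tensor"), P("data"), P("data", "tensor")),
    out_specs=P("data", "tensor"),
)
def toy_fused_update(new_kv, loc, fused_kv):
    # Per-shard scatter of the new tokens into the donated fused cache.
    return fused_kv.at[loc].set(new_kv)


fused_kv = jnp.zeros((16, 8))
new_kv = jnp.ones((4, 8))
loc = jnp.arange(4)
fused_kv = toy_fused_update(new_kv, loc, fused_kv)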

python/sgl_jax/srt/models/umt5.py

Lines changed: 76 additions & 25 deletions
@@ -261,48 +261,98 @@ def __call__(
 
     def _native_attention(self, q, k, v, forward_batch: ForwardBatch):
         """Native attention for encoder/cross-attention with T5 position bias."""
-        num_tokens, hidden = q.shape[0], q.shape[-1]
-        head_dim = hidden // self.n_heads
-
-        # Reshape to [heads, tokens, head_dim]
-        def to_heads(x):
-            n_tok = x.shape[0]
-            return jnp.transpose(x.reshape(n_tok, self.n_heads, head_dim), (1, 0, 2))
-
-        q_h, k_h, v_h = to_heads(q), to_heads(k), to_heads(v)
-
-        # Compute scores in float32
-        scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))
+        hidden = q.shape[-1]
+        head_dim = self.d_kv  # T5 uses d_kv as head dimension, not hidden // n_heads
+        n_heads = self.n_heads  # Capture as local variable for closure
+        is_cross_attn = self.is_cross_attention  # Capture as local variable
+        has_rel_bias = hasattr(self, "rel_bias")  # Capture as local variable
+
+        # Debug: print dimensions
+        jax.debug.print(
+            "UMT5 _native_attention: q.shape={q_shape}, hidden={hidden}, d_kv={d_kv}, n_heads={n_heads}, inner_dim={inner_dim}",
+            q_shape=q.shape,
+            hidden=hidden,
+            d_kv=head_dim,
+            n_heads=n_heads,
+            inner_dim=self.inner_dim,
+        )
 
         # Get sequence lengths
         q_lens = getattr(forward_batch, "extend_seq_lens", forward_batch.seq_lens)
         # Fallback if seq_lens is None: assume single sequence
         if q_lens is None:
             q_lens = jnp.array([q.shape[0]], dtype=jnp.int32)
 
-        # Add position bias for self-attention (T5-specific)
-        if not self.is_cross_attention and hasattr(self, "rel_bias"):
-            pos_bias = self._compute_position_bias(q_lens, q.shape[0], k.shape[0])
-            scores = scores + pos_bias.astype(jnp.float32)
+        rel_bias_weight = self.rel_bias.embedding.value if hasattr(self, "rel_bias") else None
 
-        # Apply masking
         kv_lens = (
             getattr(forward_batch, "encoder_seq_lens", q_lens)
             if self.is_cross_attention
             else q_lens
         )
         is_causal = self.is_decoder and not self.is_cross_attention
 
-        # Apply block_diagonal_mask
-        scores = _apply_block_diagonal_mask(scores, q_lens, kv_lens, is_causal=is_causal)
+        # Wrap computation in shard_map for data parallelism
+        in_specs = (
+            P("data", "tensor"),  # q
+            P("data", "tensor"),  # k
+            P("data", "tensor"),  # v
+            P("data"),  # q_lens
+            P("data"),  # kv_lens
+            P(None, "tensor"),  # rel_bias_weight
+        )
+        out_specs = P("data", "tensor")
+
+        def _compute_attention(q_local, k_local, v_local, q_lens_local, kv_lens_local, rel_weight):
+            # Debug: print local shapes inside shard_map
+            jax.debug.print(
+                "Inside shard_map: q_local.shape={q_shape}, n_heads={n_heads}, head_dim={head_dim}",
+                q_shape=q_local.shape,
+                n_heads=n_heads,
+                head_dim=head_dim,
+            )
+            local_n_heads = q_local.shape[-1] // head_dim
+            local_hidden = q_local.shape[-1]
+
+            # Reshape to [heads, tokens, head_dim]
+            def to_heads(x):
+                n_tok = x.shape[0]
+                return jnp.transpose(x.reshape(n_tok, local_n_heads, head_dim), (1, 0, 2))
+
+            q_h, k_h, v_h = to_heads(q_local), to_heads(k_local), to_heads(v_local)
+
+            # Compute scores in float32
+            scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))
+
+            # Add position bias for self-attention (T5-specific)
+            if not is_cross_attn and has_rel_bias:
+                pos_bias = self._compute_position_bias(
+                    q_lens_local, q_local.shape[0], k_local.shape[0], rel_weight
+                )
+                scores = scores + pos_bias.astype(jnp.float32)
+
+            # Apply block_diagonal_mask
+            scores = _apply_block_diagonal_mask(
+                scores, q_lens_local, kv_lens_local, is_causal=is_causal
+            )
+
+            # Softmax and weighted sum
+            weights = jax.nn.softmax(scores, axis=-1)
+            out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
+
+            return jnp.transpose(out, (1, 0, 2)).reshape(q_local.shape[0], local_hidden)
 
-        # Softmax and weighted sum
-        weights = jax.nn.softmax(scores, axis=-1)
-        out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
+        result = jax.shard_map(
+            _compute_attention,
+            mesh=self.mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_vma=False,
+        )(q, k, v, q_lens, kv_lens, rel_bias_weight)
 
-        return jnp.transpose(out, (1, 0, 2)).reshape(num_tokens, hidden)
+        return result
 
-    def _compute_position_bias(self, seq_lens, q_len, k_len):
+    def _compute_position_bias(self, seq_lens, q_len, k_len, rel_weight):
         """Compute T5 position bias [heads, q_len, k_len]."""
         starts = jnp.cumsum(seq_lens) - seq_lens
         indicators = jnp.zeros(q_len, dtype=jnp.int32).at[starts].set(1)
@@ -318,7 +368,8 @@ def _compute_position_bias(self, seq_lens, q_len, k_len):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        return jnp.transpose(self.rel_bias(buckets), (2, 0, 1))
+        bias = rel_weight[buckets]
+        return jnp.transpose(bias, (2, 0, 1))
 
 
 # =============================================================================
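
Note: the _compute_position_bias change indexes the raw relative-bias table instead of calling the embedding module, so the table can be passed into shard_map as a plain array sharded over the "tensor" axis. A minimal sketch with toy sizes and a placeholder bucket tensor (the real buckets come from the T5 bucketing function):

import jax.numpy as jnp

num_buckets, n_heads, q_len, k_len = 32, 4, 6, 6
rel_weight = jnp.zeros((num_buckets, n_heads))         # embedding table [num_buckets, n_heads]
buckets = jnp.zeros((q_len, k_len), dtype=jnp.int32)   # relative-position buckets

bias = rel_weight[buckets]               # [q_len, k_len, n_heads]
bias = jnp.transpose(bias, (2, 0, 1))    # [n_heads, q_len, k_len], added to the scores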

python/sgl_jax/test/mem_cache/test_kv_cache.py

Lines changed: 19 additions & 7 deletions
@@ -117,7 +117,9 @@ def test_kv_cache_update_page_size_1(self):
         total_tokens = 16
         k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
 
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=1
+        )
 
         # Expected result
         expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -132,7 +134,9 @@ def test_kv_cache_update_page_size_1_with_padding(self):
         total_tokens = 12
         k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True)
 
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=1
+        )
 
         # Expected result (should ignore padding tokens where loc == -1)
         expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -148,7 +152,9 @@ def test_kv_cache_update_page_size_4(self):
         k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
 
         # Test with page_size=4
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=4
+        )
 
         # Expected result
         expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -163,7 +169,9 @@ def test_kv_cache_update_page_size_4_with_padding(self):
         k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True)
 
         # Test with page_size=4
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=4
+        )
 
         # Expected result (should ignore padding tokens where loc == -1)
         expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -179,7 +187,9 @@ def test_kv_cache_update_page_size_8_contiguous(self):
         k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
 
         # Test with page_size=8
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=8)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=8
+        )
 
         # Expected result
         expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -204,7 +214,9 @@ def test_all_padding_tokens(self):
         original_v_cache = v_cache.copy()
 
         # Test both approaches
-        updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=8)
+        updated_k_cache, updated_v_cache = update_kv_cache(
+            k, v, loc, k_cache.copy(), v_cache.copy(), page_size=8
+        )
 
         # Cache should remain unchanged since all tokens are padding
         self.assertTrue(jnp.allclose(updated_k_cache, original_k_cache))
@@ -266,7 +278,7 @@ def test_kv_cache_update_multiple_segments_with_padding(self):
         for page_size in [1, 2, 4, 8]:
             with self.subTest(page_size=page_size):
                 updated_k_cache, updated_v_cache = update_kv_cache(
-                    k, v, loc, k_cache, v_cache, page_size=page_size
+                    k, v, loc, k_cache.copy(), v_cache.copy(), page_size=page_size
                 )
 
                 # Expected result
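
Note: the .copy() calls are needed because kv_cache_update now donates its cache buffers. A minimal sketch of the failure mode this avoids, using an illustrative function rather than the kernel itself:

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=(0,))
def bump(x):
    return x + 1


x = jnp.zeros((4,))
y1 = bump(x.copy())  # safe: only the copy's buffer is donated
y2 = bump(x.copy())  # x itself stays valid for repeated calls and assertions
# Passing x directly twice could fail, since its donated buffer may be
# invalidated after the first call.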

python/sgl_jax/test/mem_cache/test_swa_radix_cache.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ class TestSWARadixCache(unittest.TestCase):
     def setUp(self):
         # Keep KV sizes small to make tests light-weight
         self.devices = jax.devices()
-        self.mesh = Mesh([self.devices[0]], axis_names=("tensor",))
+        # Create mesh with both 'data' and 'tensor' axes for DP compatibility
+        self.mesh = Mesh(np.array([self.devices[0]]).reshape(1, 1), axis_names=("data", "tensor"))
 
         # Small buffers to avoid heavy allocations
         self.kv_head_num = 1
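
Note: a minimal sketch of the single-device mesh shape used in setUp: reshaping one device into a (1, 1) grid gives a mesh that still names both the "data" and "tensor" axes, so code written for DP + TP sharding runs unchanged in the test.

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, axis_names=("data", "tensor"))
assert dict(mesh.shape) == {"data": 1, "tensor": 1}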

python/sgl_jax/test/test_flashattention.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def create_test_data(
     causal=True,
     input_ids=None,
     model_config=None,
-    max_total_token_size=710016,
+    max_total_token_size=100000,
 ):
     """Create a real ForwardBatch for testing."""
     assert mode in ["prefill", "decode"]
