Skip to content

Commit 19cec01

Browse files
committed
fix: wan2.1
1 parent: 21a817c · commit: 19cec01

7 files changed

Lines changed: 105 additions & 38 deletions

File tree

python/sgl_jax/srt/kernels/update_kv_cache/update_kv_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def kv_cache_update_impl(
185185
@partial(
186186
jax.jit,
187187
static_argnames=["page_size", "num_slices_per_block", "kv_partition_axis"],
188+
donate_argnums=(2,),
188189
)
189190
def kv_cache_update(
190191
new_kv: jax.Array, # [total_num_token, num_kv_heads, head_dim]

python/sgl_jax/srt/layers/attention/native_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __call__(
9292
P("data"), # extend_prefix_lens
9393
P("data"), # extend_seq_lens
9494
)
95-
out_specs = P("data", "tensor", None) # attn_output
95+
out_specs = P("data", "tensor") # attn_output: [num_tokens, hidden_size]
9696

9797
attn_output = jax.shard_map(
9898
lambda q_local, k_local, v_local, seq_lens_local, loc_local, prefix_lens_local, extend_lens_local: forward_attention(

python/sgl_jax/srt/mem_cache/memory_pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import logging
33
import time
4+
from functools import partial
45

56
import jax
67
import jax.numpy as jnp
@@ -784,6 +785,7 @@ def update_fused_kv_cache_vectorized(
784785
by grouping contiguous tokens into page-sized chunks for efficient updates.
785786
"""
786787

788+
@partial(jax.jit, donate_argnums=(2,))
787789
@jax.shard_map(
788790
in_specs=(
789791
# fused_kv: sharded by data (tokens) and tensor (heads)

python/sgl_jax/srt/models/umt5.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import jax
1616
import jax.numpy as jnp
1717
from flax import nnx
18-
from jax.sharding import NamedSharding
1918
from jax.sharding import PartitionSpec as P
2019
from transformers import UMT5Config
2120

@@ -261,48 +260,98 @@ def __call__(
261260

262261
def _native_attention(self, q, k, v, forward_batch: ForwardBatch):
263262
"""Native attention for encoder/cross-attention with T5 position bias."""
264-
num_tokens, hidden = q.shape[0], q.shape[-1]
265-
head_dim = hidden // self.n_heads
266-
267-
# Reshape to [heads, tokens, head_dim]
268-
def to_heads(x):
269-
n_tok = x.shape[0]
270-
return jnp.transpose(x.reshape(n_tok, self.n_heads, head_dim), (1, 0, 2))
271-
272-
q_h, k_h, v_h = to_heads(q), to_heads(k), to_heads(v)
273-
274-
# Compute scores in float32
275-
scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))
263+
hidden = q.shape[-1]
264+
head_dim = self.d_kv # T5 uses d_kv as head dimension, not hidden // n_heads
265+
n_heads = self.n_heads # Capture as local variable for closure
266+
is_cross_attn = self.is_cross_attention # Capture as local variable
267+
has_rel_bias = hasattr(self, "rel_bias") # Capture as local variable
268+
269+
# Debug: print dimensions
270+
jax.debug.print(
271+
"UMT5 _native_attention: q.shape={q_shape}, hidden={hidden}, d_kv={d_kv}, n_heads={n_heads}, inner_dim={inner_dim}",
272+
q_shape=q.shape,
273+
hidden=hidden,
274+
d_kv=head_dim,
275+
n_heads=n_heads,
276+
inner_dim=self.inner_dim,
277+
)
276278

277279
# Get sequence lengths
278280
q_lens = getattr(forward_batch, "extend_seq_lens", forward_batch.seq_lens)
279281
# Fallback if seq_lens is None: assume single sequence
280282
if q_lens is None:
281283
q_lens = jnp.array([q.shape[0]], dtype=jnp.int32)
282284

283-
# Add position bias for self-attention (T5-specific)
284-
if not self.is_cross_attention and hasattr(self, "rel_bias"):
285-
pos_bias = self._compute_position_bias(q_lens, q.shape[0], k.shape[0])
286-
scores = scores + pos_bias.astype(jnp.float32)
285+
rel_bias_weight = self.rel_bias.embedding.value if hasattr(self, "rel_bias") else None
287286

288-
# Apply masking
289287
kv_lens = (
290288
getattr(forward_batch, "encoder_seq_lens", q_lens)
291289
if self.is_cross_attention
292290
else q_lens
293291
)
294292
is_causal = self.is_decoder and not self.is_cross_attention
295293

296-
# Apply block_diagonal_mask
297-
scores = _apply_block_diagonal_mask(scores, q_lens, kv_lens, is_causal=is_causal)
294+
# Wrap computation in shard_map for data parallelism
295+
in_specs = (
296+
P("data", "tensor"), # q
297+
P("data", "tensor"), # k
298+
P("data", "tensor"), # v
299+
P("data"), # q_lens
300+
P("data"), # kv_lens
301+
P(None, "tensor"), # rel_bias_weight
302+
)
303+
out_specs = P("data", "tensor")
304+
305+
def _compute_attention(q_local, k_local, v_local, q_lens_local, kv_lens_local, rel_weight):
306+
# Debug: print local shapes inside shard_map
307+
jax.debug.print(
308+
"Inside shard_map: q_local.shape={q_shape}, n_heads={n_heads}, head_dim={head_dim}",
309+
q_shape=q_local.shape,
310+
n_heads=n_heads,
311+
head_dim=head_dim,
312+
)
313+
local_n_heads = q_local.shape[-1] // head_dim
314+
local_hidden = q_local.shape[-1]
315+
316+
# Reshape to [heads, tokens, head_dim]
317+
def to_heads(x):
318+
n_tok = x.shape[0]
319+
return jnp.transpose(x.reshape(n_tok, local_n_heads, head_dim), (1, 0, 2))
320+
321+
q_h, k_h, v_h = to_heads(q_local), to_heads(k_local), to_heads(v_local)
298322

299-
# Softmax and weighted sum
300-
weights = jax.nn.softmax(scores, axis=-1)
301-
out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
323+
# Compute scores in float32
324+
scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))
302325

303-
return jnp.transpose(out, (1, 0, 2)).reshape(num_tokens, hidden)
326+
# Add position bias for self-attention (T5-specific)
327+
if not is_cross_attn and has_rel_bias:
328+
pos_bias = self._compute_position_bias(
329+
q_lens_local, q_local.shape[0], k_local.shape[0], rel_weight
330+
)
331+
scores = scores + pos_bias.astype(jnp.float32)
304332

305-
def _compute_position_bias(self, seq_lens, q_len, k_len):
333+
# Apply block_diagonal_mask
334+
scores = _apply_block_diagonal_mask(
335+
scores, q_lens_local, kv_lens_local, is_causal=is_causal
336+
)
337+
338+
# Softmax and weighted sum
339+
weights = jax.nn.softmax(scores, axis=-1)
340+
out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
341+
342+
return jnp.transpose(out, (1, 0, 2)).reshape(q_local.shape[0], local_hidden)
343+
344+
result = jax.shard_map(
345+
_compute_attention,
346+
mesh=self.mesh,
347+
in_specs=in_specs,
348+
out_specs=out_specs,
349+
check_vma=False,
350+
)(q, k, v, q_lens, kv_lens, rel_bias_weight)
351+
352+
return result
353+
354+
def _compute_position_bias(self, seq_lens, q_len, k_len, rel_weight):
306355
"""Compute T5 position bias [heads, q_len, k_len]."""
307356
starts = jnp.cumsum(seq_lens) - seq_lens
308357
indicators = jnp.zeros(q_len, dtype=jnp.int32).at[starts].set(1)
@@ -318,7 +367,8 @@ def _compute_position_bias(self, seq_lens, q_len, k_len):
318367
num_buckets=self.num_buckets,
319368
max_distance=self.max_distance,
320369
)
321-
return jnp.transpose(self.rel_bias(buckets), (2, 0, 1))
370+
bias = rel_weight[buckets]
371+
return jnp.transpose(bias, (2, 0, 1))
322372

323373

324374
# =============================================================================
@@ -467,8 +517,9 @@ def __call__(self, forward_batch: ForwardBatch, token_to_kv_pool=None, logits_me
467517

468518
# Dummy logits for interface compatibility
469519
bs = forward_batch.seq_lens.shape[0]
470-
dummy = jnp.zeros((bs, self.config.vocab_size), dtype=self.dtype)
471-
dummy = jax.sharding.reshard(dummy, NamedSharding(self.mesh, P(None, "tensor")))
520+
dummy = jnp.zeros(
521+
(bs, self.config.vocab_size), dtype=self.dtype, out_sharding=("data", "tensor")
522+
)
472523
return LogitsProcessorOutput(next_token_logits=dummy, hidden_states=hidden), [], [], None
473524

474525

python/sgl_jax/test/mem_cache/test_kv_cache.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def test_kv_cache_update_page_size_1(self):
117117
total_tokens = 16
118118
k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
119119

120-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1)
120+
updated_k_cache, updated_v_cache = update_kv_cache(
121+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=1
122+
)
121123

122124
# Expected result
123125
expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -132,7 +134,9 @@ def test_kv_cache_update_page_size_1_with_padding(self):
132134
total_tokens = 12
133135
k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True)
134136

135-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1)
137+
updated_k_cache, updated_v_cache = update_kv_cache(
138+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=1
139+
)
136140

137141
# Expected result (should ignore padding tokens where loc == -1)
138142
expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -148,7 +152,9 @@ def test_kv_cache_update_page_size_4(self):
148152
k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
149153

150154
# Test with page_size=4
151-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4)
155+
updated_k_cache, updated_v_cache = update_kv_cache(
156+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=4
157+
)
152158

153159
# Expected result
154160
expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -163,7 +169,9 @@ def test_kv_cache_update_page_size_4_with_padding(self):
163169
k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True)
164170

165171
# Test with page_size=4
166-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4)
172+
updated_k_cache, updated_v_cache = update_kv_cache(
173+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=4
174+
)
167175

168176
# Expected result (should ignore padding tokens where loc == -1)
169177
expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -179,7 +187,9 @@ def test_kv_cache_update_page_size_8_contiguous(self):
179187
k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False)
180188

181189
# Test with page_size=8
182-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=8)
190+
updated_k_cache, updated_v_cache = update_kv_cache(
191+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=8
192+
)
183193

184194
# Expected result
185195
expected_k_cache, expected_v_cache = self.expected_update_kv_cache(
@@ -204,7 +214,9 @@ def test_all_padding_tokens(self):
204214
original_v_cache = v_cache.copy()
205215

206216
# Test both approaches
207-
updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=8)
217+
updated_k_cache, updated_v_cache = update_kv_cache(
218+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=8
219+
)
208220

209221
# Cache should remain unchanged since all tokens are padding
210222
self.assertTrue(jnp.allclose(updated_k_cache, original_k_cache))
@@ -266,7 +278,7 @@ def test_kv_cache_update_multiple_segments_with_padding(self):
266278
for page_size in [1, 2, 4, 8]:
267279
with self.subTest(page_size=page_size):
268280
updated_k_cache, updated_v_cache = update_kv_cache(
269-
k, v, loc, k_cache, v_cache, page_size=page_size
281+
k, v, loc, k_cache.copy(), v_cache.copy(), page_size=page_size
270282
)
271283

272284
# Expected result

python/sgl_jax/test/mem_cache/test_swa_radix_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class TestSWARadixCache(unittest.TestCase):
2727
def setUp(self):
2828
# Keep KV sizes small to make tests light-weight
2929
self.devices = jax.devices()
30-
self.mesh = Mesh([self.devices[0]], axis_names=("tensor",))
30+
# Create mesh with both 'data' and 'tensor' axes for DP compatibility
31+
self.mesh = Mesh(np.array([self.devices[0]]).reshape(1, 1), axis_names=("data", "tensor"))
3132

3233
# Small buffers to avoid heavy allocations
3334
self.kv_head_num = 1

python/sgl_jax/test/test_flashattention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def create_test_data(
140140
causal=True,
141141
input_ids=None,
142142
model_config=None,
143-
max_total_token_size=710016,
143+
max_total_token_size=100000,
144144
):
145145
"""Create a real ForwardBatch for testing."""
146146
assert mode in ["prefill", "decode"]

0 commit comments

Comments (0)