Skip to content

Commit 9ce98db

Browse files
author
luoyuan.luo
committed
Address review comments
1 parent 1840775 commit 9ce98db

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

  • python/sglang/srt/layers/attention/linear/kernels/kda_blackwell

python/sglang/srt/layers/attention/linear/kernels/kda_blackwell/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ def _kda_workspace(q, T, Hv, K, V, cu_seqlens):
5151
import torch as _t
5252

5353
dev = q.device
54-
key = (Hv, K, V, dev, q.dtype)
54+
# Key by the current CUDA stream too: the scratch is process-global and
55+
# mutable, so two KDA forwards running concurrently on different streams
56+
# (e.g. two-batch overlap) must not share buffers. Within one forward all
57+
# KDA layers run on the same stream -> same key -> the reuse benefit holds.
58+
stream = _t.cuda.current_stream(device=dev).cuda_stream
59+
key = (Hv, K, V, dev, q.dtype, stream)
5560
ws = _KDA_WS.get(key)
5661

5762
# metadata: recompute only when cu_seqlens changes (object identity -> no
@@ -90,9 +95,11 @@ def _kda_workspace(q, T, Hv, K, V, cu_seqlens):
9095
eye = ws["eye"]
9196
hw = max(ws["eye_hw"], T)
9297
eye[:hw].zero_()
93-
tok = _t.arange(T, device=dev)
94-
seq_of = _t.searchsorted(cu_seqlens.long(), tok, right=True) - 1
95-
pos = (tok - cu_seqlens.long()[seq_of]) % 64
98+
# Match cu_seqlens' dtype (typically int32) so searchsorted/indexing avoid
99+
# the int64 casts, while staying correct if cu_seqlens is passed as int64.
100+
tok = _t.arange(T, device=dev, dtype=cu_seqlens.dtype)
101+
seq_of = _t.searchsorted(cu_seqlens, tok, right=True) - 1
102+
pos = (tok - cu_seqlens[seq_of]) % 64
96103
eye[tok, :, pos] = 1.0
97104
ws["eye_hw"] = T
98105
ws["cu"] = cu_seqlens

0 commit comments

Comments
 (0)