@@ -51,7 +51,12 @@ def _kda_workspace(q, T, Hv, K, V, cu_seqlens):
5151 import torch as _t
5252
5353 dev = q .device
54- key = (Hv , K , V , dev , q .dtype )
54+ # Key by the current CUDA stream too: the scratch is process-global and
55+ # mutable, so two KDA forwards running concurrently on different streams
56+ # (e.g. two-batch overlap) must not share buffers. Within one forward all
57+ # KDA layers run on the same stream -> same key -> the reuse benefit holds.
58+ stream = _t .cuda .current_stream (device = dev ).cuda_stream
59+ key = (Hv , K , V , dev , q .dtype , stream )
5560 ws = _KDA_WS .get (key )
5661
5762 # metadata: recompute only when cu_seqlens changes (object identity -> no
@@ -90,9 +95,11 @@ def _kda_workspace(q, T, Hv, K, V, cu_seqlens):
9095 eye = ws ["eye" ]
9196 hw = max (ws ["eye_hw" ], T )
9297 eye [:hw ].zero_ ()
93- tok = _t .arange (T , device = dev )
94- seq_of = _t .searchsorted (cu_seqlens .long (), tok , right = True ) - 1
95- pos = (tok - cu_seqlens .long ()[seq_of ]) % 64
98+ # Match cu_seqlens' dtype (typically int32) so searchsorted/indexing avoid
99+ # the int64 casts, while staying correct if cu_seqlens is passed as int64.
100+ tok = _t .arange (T , device = dev , dtype = cu_seqlens .dtype )
101+ seq_of = _t .searchsorted (cu_seqlens , tok , right = True ) - 1
102+ pos = (tok - cu_seqlens [seq_of ]) % 64
96103 eye [tok , :, pos ] = 1.0
97104 ws ["eye_hw" ] = T
98105 ws ["cu" ] = cu_seqlens
0 commit comments