Skip to content

Commit ed417aa

Browse files
committed
per code-rabbit comment
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 87787a8 commit ed417aa

1 file changed

Lines changed: 4 additions & 5 deletions

File tree

flashinfer/gdn_kernels/blackwell/gdn_prefill.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -221,11 +221,10 @@ def chunk_gated_delta_rule_sm100(
221 221
workspace_size = GatedDeltaNetChunkedKernel.get_workspace_size(
222 222
num_sm, B, HQ, HV, True
223 223
)
224-
if "workspace" not in cache or cache["workspace"].size(0) < workspace_size:
225-
cache["workspace"] = torch.empty(
226-
workspace_size, dtype=torch.int8, device=q.device
227-
)
228-
workspace = cache["workspace"]
224+
ws_key = f"workspace_{q.device.index}"
225+
if ws_key not in cache or cache[ws_key].size(0) < workspace_size:
226+
cache[ws_key] = torch.empty(workspace_size, dtype=torch.int8, device=q.device)
227+
workspace = cache[ws_key]
229 228

230 229
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
231 230
compiled(

0 commit comments

Comments
 (0)