We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 87787a8 commit ed417aa (Copy full SHA for ed417aa)
1 file changed
flashinfer/gdn_kernels/blackwell/gdn_prefill.py
@@ -221,11 +221,10 @@ def chunk_gated_delta_rule_sm100(
221
workspace_size = GatedDeltaNetChunkedKernel.get_workspace_size(
222
num_sm, B, HQ, HV, True
223
)
224
- if "workspace" not in cache or cache["workspace"].size(0) < workspace_size:
225
- cache["workspace"] = torch.empty(
226
- workspace_size, dtype=torch.int8, device=q.device
227
- )
228
- workspace = cache["workspace"]
+ ws_key = f"workspace_{q.device.index}"
+ if ws_key not in cache or cache[ws_key].size(0) < workspace_size:
+ cache[ws_key] = torch.empty(workspace_size, dtype=torch.int8, device=q.device)
+ workspace = cache[ws_key]
229
230
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
231
compiled(
0 commit comments