Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flashinfer/gdn_kernels/blackwell/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.utils as cutlass_utils
from cutlass.cute.runtime import from_dlpack

from flashinfer.cute_dsl.utils import get_max_active_clusters, get_num_sm
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

from .gated_delta_net_chunked import GatedDeltaNetChunkedKernel


Expand Down Expand Up @@ -157,9 +158,8 @@ def chunk_gated_delta_rule_sm100(

if "compiled" not in cache:
# --- First call: compile the kernel ---
hardware_info = cutlass_utils.HardwareInfo()
num_sm = hardware_info.get_max_active_clusters(1)
max_active_clusters = hardware_info.get_max_active_clusters(1)
num_sm = get_num_sm(q.device)
max_active_clusters = min(get_max_active_clusters(1), num_sm)
Comment thread
arpera marked this conversation as resolved.
Outdated

gdn = GatedDeltaNetChunkedKernel(
io_dtype=io_dtype,
Expand Down
Loading