Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion benchmarks/bench_gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,18 @@ def bench_fi(endpoints, h_qk, h_v, d, warmup, iters):
state_out = torch.zeros_like(h0)

fn = lambda: chunk_gated_delta_rule(
q, k, v, g, beta, None, h0, True, cu_seqlens, False, None, state_out
q,
k,
v,
g,
beta,
scale=None,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=False,
output=None,
output_state=state_out,
)
Comment thread
Observer007 marked this conversation as resolved.
times = bench_gpu_time(
fn, enable_cupti=True, dry_run_iters=warmup, repeat_iters=iters
Expand Down
48 changes: 42 additions & 6 deletions flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def __init__(
use_initial_state: bool,
store_final_state: bool = True,
enable_checkpoints: bool = False,
is_log_gate: bool = False,
is_initial_state_pool: bool = False,
is_persistent: bool = True,
):
self.io_dtype = io_dtype
Expand All @@ -224,7 +226,15 @@ def __init__(
self.use_initial_state = use_initial_state
self.store_final_state = store_final_state
self.enable_checkpoints = enable_checkpoints
self.is_log_gate = is_log_gate
self.is_initial_state_pool = is_initial_state_pool
self.is_persistent = is_persistent
self.log2_e = math.log2(math.e)

if self.is_initial_state_pool:
assert self.use_initial_state, (
"is_initial_state_pool requires use_initial_state"
)

# ------------------------------------------------------------------
# Warp assignments (12 warps total)
Expand Down Expand Up @@ -345,6 +355,8 @@ def can_implement(
mma_tiler_qs,
mma_tiler_qkv,
mma_tiler_kv,
use_initial_state=False,
is_initial_state_pool=False,
):
"""Raise CantImplementError if this configuration is not supported."""
if io_dtype not in [cutlass.Float16, cutlass.BFloat16]:
Expand All @@ -371,6 +383,11 @@ def can_implement(
raise testing.CantImplementError(
f"mma_tiler_kv={mma_tiler_kv} not supported; only (128, 128, 64) is supported"
)
if is_initial_state_pool:
if not use_initial_state:
raise testing.CantImplementError(
"is_initial_state_pool requires use_initial_state"
)

# -----------------------------------------------------------------------
# Host entry point
Expand All @@ -387,6 +404,7 @@ def __call__(
o: cute.Tensor,
cu_seqlens: cute.Tensor,
s_in: Optional[cute.Tensor],
s_in_indices: Optional[cute.Tensor],
s_out: Optional[cute.Tensor],
s_checkpoints: Optional[cute.Tensor],
cu_checkpoints: Optional[cute.Tensor],
Expand All @@ -395,6 +413,11 @@ def __call__(
tensormap_workspace: cute.Tensor,
stream: cuda.CUstream,
):
if cutlass.const_expr(self.is_initial_state_pool):
assert s_in_indices is not None and s_in is not None, (
"s_in_indices and s_in must be provided if is_initial_state_pool is True"
)

# chunk size
self.b_t = 64
h_q = q.shape[1]
Expand Down Expand Up @@ -810,6 +833,7 @@ class SharedStorage:
tma_o,
cu_seqlens,
s_in,
s_in_indices,
s_out,
s_checkpoints,
cu_checkpoints,
Expand Down Expand Up @@ -868,6 +892,8 @@ def kernel(
cu_seqlens: cute.Tensor,
# initial state (fp32) from GMEM; None if not used
mS_init: Optional[cute.Tensor],
# initial state indices (int32) from GMEM for pool mode; None otherwise
mS_init_indices: Optional[cute.Tensor],
# final state output (fp32) to GMEM; None if not stored
mS_out: Optional[cute.Tensor],
mS_checkpoints: Optional[cute.Tensor],
Expand Down Expand Up @@ -1242,6 +1268,7 @@ def _cg(num_threads):
kv_acc_producer = self._load_initial_state(
tidx,
mS_init,
mS_init_indices,
head_idx,
batch_idx,
tmem_ptr,
Expand Down Expand Up @@ -1888,15 +1915,21 @@ def load_gate_beta_warp(

# --- Gate load ---
if cutlass.const_expr(is_last_tile):
# OOB neutral: 1.0 -> log2 ~= 0.0 (no decay contribution)
tGrGate.fill(1.0)
# OOB neutral: 1.0 -> log2 ~= 0.0 (no decay contribution).
# When is_log_gate, gate is already in natural-log space; neutral is 0.
tGrGate.fill(0.0 if self.is_log_gate else 1.0)
cute.copy(tiled_copy_gate_g2r, tGgGate, tGrGate, pred=tGpGate)
else:
cute.copy(tiled_copy_gate_g2r, tGgGate, tGrGate)

# --- log2 + warp-wide inclusive prefix sum + SMEM store (always) ---
for i in range(cute.size(tGrGate)):
tGrGate[i] = cute.math.log2(tGrGate[i] + 1e-10, fastmath=True)
if cutlass.const_expr(not self.is_log_gate):
tGrGate[i] = cute.math.log2(tGrGate[i] + 1e-10, fastmath=True)
else:
# If gate is already in natural log, convert to log2
# (log2(x) = ln(x) * log2(e)).
tGrGate[i] = tGrGate[i] * self.log2_e
for offset in [1, 2, 4, 8, 16]:
for col in range(cute.size(tGrGate)):
n = cute.arch.shuffle_sync_up(
Expand Down Expand Up @@ -2935,6 +2968,7 @@ def _load_initial_state(
self,
tidx,
mS_init,
mS_init_indices,
head_idx,
batch_idx,
tmem_ptr,
Expand Down Expand Up @@ -2973,8 +3007,11 @@ def _load_initial_state(
tRT_tCrState = cute.make_rmem_tensor_like(tRT_tCcState, self.acc_dtype)
tGR_tCrState = cute.make_rmem_tensor_like(tRT_tCcState, self.state_dtype)

s_in_index = (
mS_init_indices[batch_idx] if self.is_initial_state_pool else batch_idx
)
Comment thread
Observer007 marked this conversation as resolved.
gS_init = cute.flat_divide(
mS_init[None, None, head_idx, batch_idx],
mS_init[None, None, head_idx, s_in_index],
(self.mma_tiler_kv[0], self.mma_tiler_kv[1]),
)[None, None, 0, 0]
tGR_tCgState = thr_state_r2t.partition_S(gS_init)
Expand Down Expand Up @@ -3333,8 +3370,7 @@ def compute_group_1(

gate_handle = load_gate_consumer.wait_and_advance()

max_coord = tTR_tCcShared[cute.size(tTR_tCcShared) - 1]
cumprod_total = sCumprod[max_coord[1], 0, gate_handle.index]
cumprod_total = sCumprod[sCumprod.shape[0] - 1, 0, gate_handle.index]

valid_state = not is_first_chunk or self.use_initial_state
if cutlass.const_expr(valid_state):
Expand Down
30 changes: 29 additions & 1 deletion flashinfer/gdn_kernels/blackwell/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _get_compiled_cache(
use_initial_state: bool,
store_final_state: bool,
enable_checkpoints: bool,
is_log_gate: bool,
is_initial_state_pool: bool,
):
"""Return a mutable dict that lazily stores the compiled kernel."""
return {}
Expand Down Expand Up @@ -97,11 +99,13 @@ def chunk_gated_delta_rule_sm100(
output: torch.Tensor,
cu_seqlens: torch.Tensor,
initial_state: Optional[torch.Tensor],
initial_state_indices: Optional[torch.Tensor],
output_state: Optional[torch.Tensor],
scale: float,
checkpoint_every_n_tokens: int = 0,
cu_checkpoints: Optional[torch.Tensor] = None,
output_checkpoints: Optional[torch.Tensor] = None,
is_log_gate: bool = False,
) -> None:
"""Execute the Blackwell chunked GDN prefill kernel.

Expand All @@ -115,18 +119,30 @@ def chunk_gated_delta_rule_sm100(
beta: ``(total_tokens, HO)`` float32, update gate
output: ``(total_tokens, HO, DK)`` float16/bfloat16, pre-allocated
cu_seqlens: ``(num_seqs + 1,)`` int32
initial_state: ``(num_seqs, HO, DK, DK)`` float32/bfloat16, or None
initial_state: ``(pool_size, HO, DK, DK)`` float32/bfloat16, or None.
When ``initial_state_indices`` is None, ``pool_size == num_seqs``
and the i-th sequence reads its initial state from row i. When
``initial_state_indices`` is not None (pool mode), each sequence
reads its initial state from ``initial_state[initial_state_indices[i]]``.
initial_state_indices: ``(num_seqs,)`` int32, or None.
Pool indices selecting which row of ``initial_state`` each sequence
uses. Requires ``initial_state`` to be provided.
output_state: ``(num_seqs, HO, DK, DK)`` float32/bfloat16, or None
scale: attention scale factor (must not be 0)
checkpoint_every_n_tokens: store intermediate state every N tokens (0 = disabled)
cu_checkpoints: ``(num_seqs + 1,)`` int32, cumulative checkpoint counts
output_checkpoints: ``(total_checkpoints, HO, DK, DK)`` float32/bfloat16, or None
is_log_gate: if True, ``gate`` is already in natural-log space
(i.e. caller provides ``log(gate)``). Default False.
"""
HQ = q.size(1)
HV = v.size(1)
DK = q.size(2)
is_GQA = HQ >= HV
use_initial_state = initial_state is not None
is_initial_state_pool = initial_state_indices is not None
if is_initial_state_pool and not use_initial_state:
raise ValueError("initial_state_indices requires initial_state to be provided")
store_final_state = output_state is not None
enable_checkpoints = checkpoint_every_n_tokens > 0
io_dtype = _cutlass_io_dtype(q.dtype)
Expand All @@ -153,6 +169,8 @@ def chunk_gated_delta_rule_sm100(
use_initial_state,
store_final_state,
enable_checkpoints,
is_log_gate,
is_initial_state_pool,
)

if "compiled" not in cache:
Expand All @@ -175,6 +193,8 @@ def chunk_gated_delta_rule_sm100(
use_initial_state=use_initial_state,
store_final_state=store_final_state,
enable_checkpoints=enable_checkpoints,
is_log_gate=is_log_gate,
is_initial_state_pool=is_initial_state_pool,
is_persistent=True,
)

Expand Down Expand Up @@ -214,6 +234,12 @@ def chunk_gated_delta_rule_sm100(
mode=3, stride_order=(0, 1, 2, 3), divisibility=DK
)

s_in_indices_cute = None
if is_initial_state_pool:
s_in_indices_cute = from_dlpack(
initial_state_indices, assumed_align=4
).mark_layout_dynamic()

s_out_cute = None
if store_final_state:
s_out_cute = from_dlpack(_output_state, assumed_align=16)
Expand Down Expand Up @@ -250,6 +276,7 @@ def chunk_gated_delta_rule_sm100(
o_cute,
cu_seqlens_cute,
s_in_cute,
s_in_indices_cute,
s_out_cute,
s_checkpoints_cute,
cu_checkpoints_cute,
Expand Down Expand Up @@ -285,6 +312,7 @@ def chunk_gated_delta_rule_sm100(
output,
cu_seqlens,
_initial_state,
initial_state_indices if is_initial_state_pool else None,
_output_state,
output_checkpoints,
cu_checkpoints,
Expand Down
40 changes: 38 additions & 2 deletions flashinfer/gdn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def chunk_gated_delta_rule(
beta: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this parameter makes more sense to insert right after initial_state, but can we double-check whether there are any backwards-compatibility concerns (breaking positional ordering) here upstream? Not sure if perplexity has picked this up yet.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaixih @YAMY1234 @hlu1 Do you know whether SGLang already uses the latest gdn_prefill.py interface? (If so, this would be an interface breakage.)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we are aware of this change. In fact, we requested this improvement in sgl-project/sglang#22921.

output_final_state: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
Expand All @@ -112,6 +113,7 @@ def chunk_gated_delta_rule(
state_checkpoints: Optional[torch.Tensor] = None,
checkpoint_cu_starts: Optional[torch.Tensor] = None,
checkpoint_every_n_tokens: int = 0,
is_log_gate: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Chunked Gated Delta Rule (GDN) attention for prefill.

Expand Down Expand Up @@ -140,8 +142,17 @@ def chunk_gated_delta_rule(
Scale factor for the attention scores.
If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
initial_state (Optional[torch.Tensor]):
Initial KV state of shape ``[num_seqs, num_sab_heads, head_size, head_size]``.
Must be float32. If None, starts from zero state. Default: ``None``.
Initial KV state of shape ``[pool_size, num_sab_heads, head_size, head_size]``.
Must be float32 or bfloat16. If None, starts from zero state.
``pool_size`` equals ``num_seqs`` unless ``initial_state_indices`` is
provided, in which case ``pool_size`` can be anything and each
sequence ``i`` reads its initial state from row
``initial_state_indices[i]``. Default: ``None``.
initial_state_indices (Optional[torch.Tensor]):
Optional int32 pool indices of shape ``[num_seqs]``. When provided,
the i-th sequence reads its initial state from
``initial_state[initial_state_indices[i]]``. Requires
``initial_state`` to be provided. SM100-only. Default: ``None``.
output_final_state (bool):
Whether to output the final state. Default: ``False``.
cu_seqlens (torch.Tensor):
Expand Down Expand Up @@ -171,6 +182,12 @@ def chunk_gated_delta_rule(
checkpoint_every_n_tokens (int):
Store intermediate state every N tokens. Must be a multiple of
the chunk size (64). 0 means disabled (default).
is_log_gate (bool):
If ``True``, the ``g`` tensor is interpreted as already being in
natural-log space (i.e. the caller has applied ``torch.log(alpha)``).
If ``False`` (default), ``g`` is taken to be in the standard
multiplicative space and the kernel applies ``log2`` internally.
SM100-only.

Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Expand Down Expand Up @@ -212,6 +229,9 @@ def chunk_gated_delta_rule(
"when checkpoint_every_n_tokens == 0"
)

if initial_state_indices is not None and initial_state is None:
raise ValueError("initial_state_indices requires initial_state to be provided")

assert cu_seqlens is not None, "cu_seqlens is required for varlen mode"

num_seqs = cu_seqlens.size(0) - 1
Expand Down Expand Up @@ -314,6 +334,12 @@ def chunk_gated_delta_rule(
if checkpoint_every_n_tokens > 0 and checkpoint_cu_starts is not None:
_cu_checkpoints = checkpoint_cu_starts.to(torch.int32)

_initial_state_indices = (
initial_state_indices.to(torch.int32)
if initial_state_indices is not None
else None
)

chunk_gated_delta_rule_sm100(
q,
k,
Expand All @@ -323,14 +349,24 @@ def chunk_gated_delta_rule(
output,
cu_seqlens.to(torch.int32),
initial_state,
_initial_state_indices,
output_state if output_final_state else None,
_scale,
checkpoint_every_n_tokens=checkpoint_every_n_tokens,
cu_checkpoints=_cu_checkpoints,
output_checkpoints=state_checkpoints,
is_log_gate=is_log_gate,
)
else:
# SM90 Hopper path (C++ JIT kernel)
if is_log_gate:
raise NotImplementedError(
"is_log_gate=True is only supported on SM100 (Blackwell)"
)
if initial_state_indices is not None:
raise NotImplementedError(
"initial_state_indices is only supported on SM100 (Blackwell)"
)
workspace_size = get_device_sm_count(device) * 128
workspace_buffer = _get_cache_buf(
"gdn_prefill_workspace", workspace_size, device
Expand Down
17 changes: 14 additions & 3 deletions tests/gdn/reference_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ def blockwise_delta_rule(
block_size: int = 32,
scale_factor=1.0,
state_dtype: torch.dtype = torch.float32,
initial_state: torch.Tensor | None = None, # [pool_size, H, head_size, head_size]
initial_state_indices: torch.Tensor | None = None, # [num_seqs] int32 pool indices
# intermediate_outputs = None, # debug output
) -> torch.Tensor:
total_seqlen = q.size(0)
Expand Down Expand Up @@ -401,9 +403,18 @@ def blockwise_delta_rule(
for seq_idx, seq_start in enumerate(seq_offset[:-1]):
seq_end = seq_offset[seq_idx + 1]
blk_offset = seq_start
state_HKV = torch.zeros(
(num_sab_heads, head_size, head_size), dtype=state_dtype, device=q.device
)
if initial_state is not None:
if initial_state_indices is not None:
pool_idx = int(initial_state_indices[seq_idx].item())
else:
pool_idx = seq_idx
state_HKV = initial_state[pool_idx].to(state_dtype).to(q.device)
else:
state_HKV = torch.zeros(
(num_sab_heads, head_size, head_size),
dtype=state_dtype,
device=q.device,
)
while blk_offset < seq_end:
is_full_block = seq_end - blk_offset >= block_size
valid_len = block_size if is_full_block else seq_end - blk_offset
Expand Down
Loading
Loading