-
Notifications
You must be signed in to change notification settings - Fork 941
[feat] add log gate and initial state pool support in blackwell gdn prefill #3167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
f9f68fc
eec96ed
869ad04
ceaef85
bc63825
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -104,6 +104,7 @@ def chunk_gated_delta_rule( | |
| beta: Optional[torch.Tensor] = None, | ||
| scale: Optional[float] = None, | ||
| initial_state: Optional[torch.Tensor] = None, | ||
| initial_state_indices: Optional[torch.Tensor] = None, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I understand this parameter makes more sense to insert right after initial_state; but can we double-check if there are any backwards-compatibility concerns (breaking positional ordering) here upstream? Not sure if perplexity has picked this up yet.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, we are aware of this change. Actually, we requested this improvement in sgl-project/sglang#22921. |
||
| output_final_state: bool = False, | ||
| cu_seqlens: Optional[torch.Tensor] = None, | ||
| use_qk_l2norm_in_kernel: bool = False, | ||
|
|
@@ -112,6 +113,7 @@ def chunk_gated_delta_rule( | |
| state_checkpoints: Optional[torch.Tensor] = None, | ||
| checkpoint_cu_starts: Optional[torch.Tensor] = None, | ||
| checkpoint_every_n_tokens: int = 0, | ||
| is_log_gate: bool = False, | ||
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
| r"""Chunked Gated Delta Rule (GDN) attention for prefill. | ||
|
|
||
|
|
@@ -140,8 +142,17 @@ def chunk_gated_delta_rule( | |
| Scale factor for the attention scores. | ||
| If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``. | ||
| initial_state (Optional[torch.Tensor]): | ||
| Initial KV state of shape ``[num_seqs, num_sab_heads, head_size, head_size]``. | ||
| Must be float32. If None, starts from zero state. Default: ``None``. | ||
| Initial KV state of shape ``[pool_size, num_sab_heads, head_size, head_size]``. | ||
| Must be float32 or bfloat16. If None, starts from zero state. | ||
| ``pool_size`` equals ``num_seqs`` unless ``initial_state_indices`` is | ||
| provided, in which case ``pool_size`` can be anything and each | ||
| sequence ``i`` reads its initial state from row | ||
| ``initial_state_indices[i]``. Default: ``None``. | ||
| initial_state_indices (Optional[torch.Tensor]): | ||
| Optional int32 pool indices of shape ``[num_seqs]``. When provided, | ||
| the i-th sequence reads its initial state from | ||
| ``initial_state[initial_state_indices[i]]``. Requires | ||
| ``initial_state`` to be provided. SM100-only. Default: ``None``. | ||
| output_final_state (bool): | ||
| Whether to output the final state. Default: ``False``. | ||
| cu_seqlens (torch.Tensor): | ||
|
|
@@ -171,6 +182,12 @@ def chunk_gated_delta_rule( | |
| checkpoint_every_n_tokens (int): | ||
| Store intermediate state every N tokens. Must be a multiple of | ||
| the chunk size (64). 0 means disabled (default). | ||
| is_log_gate (bool): | ||
| If ``True``, the ``g`` tensor is interpreted as already being in | ||
| natural-log space (i.e. the caller has applied ``torch.log(alpha)``). | ||
| If ``False`` (default), ``g`` is taken to be in the standard | ||
| multiplicative space and the kernel applies ``log2`` internally. | ||
| SM100-only. | ||
|
|
||
| Returns: | ||
| Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||
|
|
@@ -212,6 +229,9 @@ def chunk_gated_delta_rule( | |
| "when checkpoint_every_n_tokens == 0" | ||
| ) | ||
|
|
||
| if initial_state_indices is not None and initial_state is None: | ||
| raise ValueError("initial_state_indices requires initial_state to be provided") | ||
|
|
||
| assert cu_seqlens is not None, "cu_seqlens is required for varlen mode" | ||
|
|
||
| num_seqs = cu_seqlens.size(0) - 1 | ||
|
|
@@ -314,6 +334,12 @@ def chunk_gated_delta_rule( | |
| if checkpoint_every_n_tokens > 0 and checkpoint_cu_starts is not None: | ||
| _cu_checkpoints = checkpoint_cu_starts.to(torch.int32) | ||
|
|
||
| _initial_state_indices = ( | ||
| initial_state_indices.to(torch.int32) | ||
| if initial_state_indices is not None | ||
| else None | ||
| ) | ||
|
|
||
| chunk_gated_delta_rule_sm100( | ||
| q, | ||
| k, | ||
|
|
@@ -323,14 +349,24 @@ def chunk_gated_delta_rule( | |
| output, | ||
| cu_seqlens.to(torch.int32), | ||
| initial_state, | ||
| _initial_state_indices, | ||
| output_state if output_final_state else None, | ||
| _scale, | ||
| checkpoint_every_n_tokens=checkpoint_every_n_tokens, | ||
| cu_checkpoints=_cu_checkpoints, | ||
| output_checkpoints=state_checkpoints, | ||
| is_log_gate=is_log_gate, | ||
| ) | ||
| else: | ||
| # SM90 Hopper path (C++ JIT kernel) | ||
| if is_log_gate: | ||
| raise NotImplementedError( | ||
| "is_log_gate=True is only supported on SM100 (Blackwell)" | ||
| ) | ||
| if initial_state_indices is not None: | ||
| raise NotImplementedError( | ||
| "initial_state_indices is only supported on SM100 (Blackwell)" | ||
| ) | ||
| workspace_size = get_device_sm_count(device) * 128 | ||
| workspace_buffer = _get_cache_buf( | ||
| "gdn_prefill_workspace", workspace_size, device | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.