5 changes: 5 additions & 0 deletions python/sglang/srt/layers/attention/mamba/mamba.py
@@ -534,6 +534,9 @@ def forward(
layer_cache, MambaPool.SpeculativeState
), "layer_cache must be SpeculativeState for speculative decoding"
draft_token_num = metadata.draft_token_num
self.intermediate_state_indices = torch.arange(
num_decodes, dtype=torch.int32, device=state_indices_tensor_d.device
)

# Reshape for batch processing
hidden_states_B_C_d_reshaped = hidden_states_B_C_d.view(
@@ -548,6 +551,7 @@
self.activation,
conv_state_indices=state_indices_tensor_d[:num_decodes],
intermediate_conv_window=layer_cache.intermediate_conv_window[0],
intermediate_state_indices=self.intermediate_state_indices,
retrieve_next_token=metadata.retrieve_next_token,
retrieve_next_sibling=metadata.retrieve_next_sibling,
retrieve_parent_token=metadata.retrieve_parent_token,
@@ -621,6 +625,7 @@ def forward(
intermediate_states_buffer=layer_cache.intermediate_ssm,
cache_steps=draft_token_num,
retrieve_parent_token=metadata.retrieve_parent_token,
intermediate_state_indices=self.intermediate_state_indices,
)
else:
selective_state_update(
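A minimal sketch of the idea behind the mamba.py change above; the example values (num_decodes, the pool slot numbers) and variable names outside the diff are illustrative, not the PR's exact code. The persistent Mamba state is still addressed by each request's pool slot, while the per-step intermediate buffers are now addressed by a dense 0..num_decodes-1 index built with torch.arange:

import torch

# Illustrative values; in the PR these come from the forward() metadata.
num_decodes = 4
state_indices_tensor_d = torch.tensor([17, 3, 42, 8], dtype=torch.int32)  # pool slots

# Dense indices for the intermediate (speculative) conv/SSM buffers.
intermediate_state_indices = torch.arange(num_decodes, dtype=torch.int32)

# Persistent conv/SSM state: looked up by pool slot, as before.
conv_state_indices = state_indices_tensor_d[:num_decodes]

# Intermediate buffers: indexed densely, so they only need num_decodes entries
# instead of one entry per pool slot.
print(conv_state_indices.tolist())          # [17, 3, 42, 8]
print(intermediate_state_indices.tolist())  # [0, 1, 2, 3]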
21 changes: 19 additions & 2 deletions python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
@@ -56,6 +56,14 @@ def softplus(dt):
is not None
}
)
@triton.heuristics(
{
"HAS_INTERMEDIATE_STATE_INDICES": lambda args: args[
"intermediate_state_indices_ptr"
]
is not None
}
)
@triton.jit(do_not_specialize=["T"])
def _selective_scan_update_kernel(
# Pointers to matrices
@@ -74,6 +82,7 @@ def _selective_scan_update_kernel(
intermediate_states_buffer,
cache_steps,
retrieve_parent_token_ptr,
intermediate_state_indices_ptr,
# Matrix dimensions
batch,
T,
@@ -130,6 +139,7 @@
DISABLE_STATE_UPDATE: tl.constexpr,
CACHE_INTERMEDIATE_STATES: tl.constexpr,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,
HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
@@ -177,7 +187,10 @@

cache_idx = -1
if CACHE_INTERMEDIATE_STATES:
if HAS_STATE_BATCH_INDICES:
if HAS_INTERMEDIATE_STATE_INDICES:
intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to(tl.int64)
cache_idx = intermediate_state_idx
elif HAS_STATE_BATCH_INDICES:
cache_idx = state_batch_idx
else:
cache_idx = pid_b
@@ -250,7 +263,7 @@
if state_batch_idx != pad_slot_id:
cache_ptr_base = (
intermediate_states_buffer
+ state_batch_idx * cache_steps * nheads * dim * dstate
+ cache_idx * cache_steps * nheads * dim * dstate
+ current_step_idx * nheads * dim * dstate
+ pid_h * dim * dstate
)
@@ -300,6 +313,7 @@ def selective_state_update(
intermediate_states_buffer=None,
cache_steps=None,
retrieve_parent_token=None,
intermediate_state_indices=None,
):
"""
Argument:
@@ -324,6 +338,8 @@
intermediate_states_buffer: Buffer to cache intermediate states
cache_steps: Total number of steps in the buffer
retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention
intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations.
If provided, uses these indices instead of state_batch_indices for the buffer.
"""
if state.dim() == 3:
state = state.unsqueeze(1)
@@ -426,6 +442,7 @@
intermediate_states_buffer,
cache_steps if cache_steps is not None else 0,
retrieve_parent_token,
intermediate_state_indices,
batch,
T,
nheads,
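A hedged plain-Python sketch of the kernel logic the mamba_ssm.py hunks above change; the helper names (resolve_cache_idx, intermediate_offset) and the example numbers are illustrative, only the index precedence and the offset arithmetic mirror the Triton code:

import torch

def resolve_cache_idx(pid_b, state_batch_indices=None, intermediate_state_indices=None):
    # New precedence in the kernel: dedicated intermediate indices win,
    # then state batch indices, then the program id itself.
    if intermediate_state_indices is not None:
        return int(intermediate_state_indices[pid_b])
    if state_batch_indices is not None:
        return int(state_batch_indices[pid_b])
    return pid_b

def intermediate_offset(cache_idx, current_step_idx, pid_h, cache_steps, nheads, dim, dstate):
    # Flat offset into intermediate_states_buffer viewed as
    # (entries, cache_steps, nheads, dim, dstate); the second hunk switches
    # the leading index from state_batch_idx to cache_idx.
    return ((cache_idx * cache_steps + current_step_idx) * nheads + pid_h) * dim * dstate

# Example: decode request 2 lives in pool slot 42 but uses intermediate slot 2.
cache_idx = resolve_cache_idx(
    pid_b=2,
    state_batch_indices=torch.tensor([17, 3, 42, 8]),
    intermediate_state_indices=torch.arange(4),
)
print(cache_idx)  # 2
print(intermediate_offset(cache_idx, current_step_idx=1, pid_h=0,
                          cache_steps=4, nheads=8, dim=64, dstate=16))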
8 changes: 7 additions & 1 deletion python/sglang/srt/server_args.py
@@ -1361,7 +1361,13 @@ def _handle_model_specific_adjustments(self):
else:
self.quantization = model_config.quantization
self.moe_runner_backend = "flashinfer_cutlass"
if not self.disable_radix_cache:

if not self.disable_radix_cache and self.speculative_algorithm is not None:
logger.warning(
"Disabling radix cache since speculative decoding for NemotronHForCausalLM is not supported with radix cache yet."
)
self.disable_radix_cache = True
elif not self.disable_radix_cache:
logger.warning(
"Disabling overlap schedule since MambaRadixCache is not compatible with "
"overlap schedule currently, try to use --disable-radix-cache if overlap schedule is necessary"
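A short hedged sketch of the server_args adjustment above; the standalone function name adjust_radix_cache and the example arguments are illustrative (in the PR the check lives inside ServerArgs._handle_model_specific_adjustments). When a speculative algorithm is configured for this model path, the radix cache is force-disabled with a warning instead of falling through to the existing overlap-schedule warning:

import logging

logger = logging.getLogger(__name__)

def adjust_radix_cache(disable_radix_cache: bool, speculative_algorithm) -> bool:
    # Returns the (possibly updated) disable_radix_cache flag.
    if not disable_radix_cache and speculative_algorithm is not None:
        logger.warning(
            "Disabling radix cache since speculative decoding for "
            "NemotronHForCausalLM is not supported with radix cache yet."
        )
        return True
    return disable_radix_cache

print(adjust_radix_cache(False, "EAGLE"))  # True: radix cache forced off
print(adjust_radix_cache(False, None))     # False: unchanged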