33 changes: 29 additions & 4 deletions tensorrt_llm/_torch/modules/fla/chunk.py
@@ -29,6 +29,8 @@ def chunk_gated_delta_rule_fwd(
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
initial_state_indices: Optional[torch.Tensor],
inplace_indexed_state_update: bool,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
@@ -54,7 +56,9 @@ def chunk_gated_delta_rule_fwd(
u=u,
g=g,
initial_state=initial_state,
initial_state_indices=initial_state_indices,
output_final_state=output_final_state,
inplace_indexed_state_update=inplace_indexed_state_update,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
@@ -86,6 +90,8 @@ def forward(
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
initial_state_indices: Optional[torch.Tensor],
inplace_indexed_state_update: bool,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -102,6 +108,8 @@ def forward(
beta=beta,
scale=scale,
initial_state=initial_state,
initial_state_indices=initial_state_indices,
inplace_indexed_state_update=inplace_indexed_state_update,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
@@ -117,6 +125,8 @@ def chunk_gated_delta_rule(
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
initial_state_indices: Optional[torch.Tensor] = None,
inplace_indexed_state_update: bool = False,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
@@ -141,6 +151,13 @@ def chunk_gated_delta_rule(
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
initial_state_indices (Optional[torch.Tensor]):
Optional state-pool indices of shape `[N]` selecting the slots to
read from `initial_state`.
inplace_indexed_state_update (Optional[bool]):
Explicit opt-in for writing indexed final states back into
`initial_state` in-place. Callers are responsible for ensuring the
selected slots are safe to update without aliasing races.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
cu_seqlens (torch.LongTensor):
@@ -211,12 +228,18 @@ def chunk_gated_delta_rule(
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
num_sequences = len(cu_seqlens) - 1
if initial_state_indices is not None:
if initial_state_indices.shape[0] != num_sequences:
raise ValueError(
f"The number of initial-state indices is expected to be equal to the number of input "
f"sequences, i.e., {num_sequences} rather than {initial_state_indices.shape[0]}."
)
elif initial_state is not None and initial_state.shape[
0] != num_sequences:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
f"i.e., {num_sequences} rather than {initial_state.shape[0]}.")
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
@@ -227,6 +250,8 @@
beta,
scale,
initial_state,
initial_state_indices,
inplace_indexed_state_update,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
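For review context, a minimal usage sketch of the new indexed-state API in `chunk_gated_delta_rule` (all shapes, dtypes, pool size, and slot ids below are illustrative assumptions, and the Triton kernels require a CUDA device):

```python
import torch

from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule

# Two packed variable-length sequences (lengths 48 and 80), so B must be 1.
B, T, H, K, V = 1, 128, 4, 64, 64
S = 16  # size of the shared state pool; only two of its slots are used here
device, dtype = "cuda", torch.bfloat16

q = torch.randn(B, T, H, K, device=device, dtype=dtype)
k = torch.randn(B, T, H, K, device=device, dtype=dtype)
v = torch.randn(B, T, H, V, device=device, dtype=dtype)
g = torch.randn(B, T, H, device=device, dtype=torch.float32)  # gating term
beta = torch.rand(B, T, H, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, 48, 128], device=device, dtype=torch.long)

state_pool = torch.zeros(S, H, K, V, device=device, dtype=torch.float32)
slot_ids = torch.tensor([5, 2], device=device)  # one distinct slot per sequence

o, _ = chunk_gated_delta_rule(
    q=q, k=k, v=v, g=g, beta=beta,
    initial_state=state_pool,            # the whole pool, not a gathered copy
    initial_state_indices=slot_ids,      # read/write state_pool[5] and state_pool[2]
    inplace_indexed_state_update=True,   # explicit opt-in required by the indexed path
    output_final_state=False,            # the pool already holds the updated states
    cu_seqlens=cu_seqlens,
    use_qk_l2norm_in_kernel=True,
)
```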
29 changes: 26 additions & 3 deletions tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
@@ -19,6 +19,7 @@
@triton.heuristics({
"USE_G": lambda args: args["g"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"USE_INDEXED_STATE": lambda args: args["h0_i"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
@@ -42,6 +43,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
g,
h,
h0,
h0_i,
ht,
cu_seqlens,
chunk_offsets,
@@ -54,6 +56,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_INDEXED_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
@@ -91,10 +94,16 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
stride_h = H * K * V
stride_k = Hg * K
stride_w = H * K
if USE_INDEXED_STATE:
state_index = tl.load(h0_i + i_n).to(tl.int64)
h0 = h0 + state_index * stride_h
ht = h0
if USE_INITIAL_STATE:
h0 = h0 + i_nh * K * V
h0 = h0 + ((i_h if USE_INDEXED_STATE else i_nh) * K * V)
if STORE_FINAL_STATE:
ht = ht + i_nh * K * V
elif USE_INDEXED_STATE:
ht = ht + i_h * K * V

# load initial state
if USE_INITIAL_STATE:
@@ -209,7 +218,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 += tl.dot(b_k, b_v_new)

# epilogue
if STORE_FINAL_STATE:
if STORE_FINAL_STATE or USE_INDEXED_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
(1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@@ -239,7 +248,9 @@ def chunk_gated_delta_rule_fwd_h(
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
output_final_state: bool = False,
inplace_indexed_state_update: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
@@ -262,8 +273,14 @@
assert K <= 256, "current kernel does not support head dimension larger than 256."

h = k.new_empty(B, NT, H, K, V)
use_indexed_state = initial_state is not None and initial_state_indices is not None
if use_indexed_state and not inplace_indexed_state_update:
raise ValueError(
"Indexed chunk state updates require inplace_indexed_state_update=True."
)
store_final_state_in_kernel = output_final_state and not use_indexed_state
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
if output_final_state else None)
if store_final_state_in_kernel else None)

v_new = torch.empty_like(u) if save_new_value else None

@@ -278,6 +295,7 @@ def grid(meta):
g=g,
h=h,
h0=initial_state,
h0_i=initial_state_indices,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
@@ -291,4 +309,9 @@ def grid(meta):
num_warps=4,
num_stages=2,
)
if output_final_state and use_indexed_state:
# The indexed kernel path updates h0 in-place, so returning
# the final state means gathering those updated slots back out.
final_state = initial_state.index_select(
0, initial_state_indices.to(torch.long))
return h, v_new, final_state
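As a reading aid, the state-pool bookkeeping the kernel performs on the indexed path can be written in plain PyTorch roughly as below (a sketch of the semantics only, not the kernel's implementation; shapes are illustrative):

```python
import torch

def indexed_state_semantics(pool: torch.Tensor,         # [S, H, K, V] shared slot pool
                            slot_ids: torch.Tensor,     # [N] slot per sequence
                            final_states: torch.Tensor  # [N, H, K, V] per-sequence results
                            ) -> torch.Tensor:
    """Read h0 from pool[slot] per sequence, write ht back into the same slot in-place."""
    idx = slot_ids.to(torch.long)
    h0 = pool.index_select(0, idx)                          # what the kernel loads
    pool.index_copy_(0, idx, final_states.to(pool.dtype))   # what the epilogue stores
    return h0

# Toy check: two sequences mapped to slots 3 and 0 of an 8-slot pool.
pool = torch.zeros(8, 2, 4, 4)
h0 = indexed_state_semantics(pool, torch.tensor([3, 0]), torch.ones(2, 2, 4, 4))
assert h0.eq(0).all() and pool[3].eq(1).all() and pool[0].eq(1).all()
```

The `index_select` at the end of `chunk_gated_delta_rule_fwd_h` is the same gather, used only when a caller still asks for `output_final_state` on the indexed path.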
@@ -57,6 +57,7 @@ def create_strategy(
payload_in_workspace: bool = False,
alltoall_result_do_sum: bool = True,
use_flashinfer: bool = False,
hidden_size: Optional[int] = None,
) -> Optional[Communication]:
"""
Create the best communication method for the given configuration
@@ -78,6 +79,9 @@
expert_size_per_partition: Number of experts per partition (required for DeepEP)
payload_in_workspace: If True, final_hidden_states is already in workspace (for NVLinkOneSided)
alltoall_result_do_sum: If True, sum the alltoall results (for NVLinkTwoSided)
hidden_size: Actual MoE activation dimension (the A2A payload width).
For latent-MoE models this is moe_latent_size, not pretrained_config.hidden_size.
Falls back to pretrained_config.hidden_size when not provided.
# TODO: Need a way to indicate whether EPLB is enabled.

Returns:
@@ -89,7 +93,8 @@
"""
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
if hidden_size is None:
hidden_size = model_config.pretrained_config.hidden_size
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
@@ -120,6 +125,7 @@
payload_in_workspace,
alltoall_result_do_sum,
use_flashinfer,
hidden_size=hidden_size,
)

# Auto-selection: Try strategies in priority order using try-catch
Expand Down Expand Up @@ -218,6 +224,7 @@ def _create_forced_method(
payload_in_workspace: bool,
alltoall_result_do_sum: bool,
use_flashinfer: bool,
hidden_size: Optional[int] = None,
) -> Communication:
"""
Create a specific method (for debugging/testing)
@@ -228,7 +235,8 @@
"""
# Extract parameters from model_config
mapping = model_config.mapping
hidden_size = model_config.pretrained_config.hidden_size
if hidden_size is None:
hidden_size = model_config.pretrained_config.hidden_size
act_dtype = model_config.torch_dtype
quant_config = model_config.quant_config
max_num_tokens = model_config.max_num_tokens
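A small sketch of the sizing rule the new `hidden_size` argument encodes; the helper name and values below are illustrative, not part of the factory:

```python
from typing import Optional

def resolve_a2a_payload_width(pretrained_hidden_size: int,
                              moe_latent_size: Optional[int] = None) -> int:
    """Width of the MoE all-to-all payload: the latent size for latent-MoE models,
    otherwise the model's hidden_size (mirrors the fallback in create_strategy)."""
    return moe_latent_size if moe_latent_size is not None else pretrained_hidden_size

assert resolve_a2a_payload_width(4096) == 4096      # dense-hidden MoE: payload = hidden_size
assert resolve_a2a_payload_width(4096, 512) == 512  # latent MoE: payload = latent width
```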
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -447,6 +447,7 @@ def _create_comm_strategy_auto(self) -> Communication:
# Keep updated with more supported backends.
alltoall_result_do_sum=True,
use_flashinfer=self.use_flashinfer,
hidden_size=self.hidden_size,
)

def forward_impl(
14 changes: 7 additions & 7 deletions tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
@@ -710,22 +710,22 @@ def forward_extend(
g = g.unsqueeze(0)
beta = beta.unsqueeze(0)

recurrent_state = ssm_states[cache_indices]

core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
core_attn_out, _ = chunk_gated_delta_rule(
q=query,
k=key,
v=value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=True,
initial_state=ssm_states,
initial_state_indices=cache_indices,
# This path writes recurrent state directly back into the shared
# pool; callers **must** ensure cache_indices do not alias live slots.
inplace_indexed_state_update=True,
output_final_state=False,
cu_seqlens=query_start_loc_long,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False)
ssm_states[cache_indices] = last_recurrent_state

return core_attn_out

14 changes: 11 additions & 3 deletions tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py
@@ -30,6 +30,7 @@ def __init__(
model: torch.nn.Module,
lora_model_config: Optional[LoraModelConfig],
device: str = "cuda",
max_tokens_per_seq: int = 1,
):
"""
Initialize the CUDA Graph LoRA manager.
@@ -41,12 +42,14 @@
model: Model to get layerwise LoRA info
lora_model_config: LoRA model configuration
device: Device to allocate tensors on
max_tokens_per_seq: Maximum number of tokens per sequence (>1 for spec decode)
"""
self.max_lora_size = max_lora_size
self.max_batch_size = max_batch_size
self.max_lora_rank = max_lora_rank
self.device = device

self.max_tokens_per_seq = max_tokens_per_seq
self.adapter_slot_manager = AdapterSlotManager(max_lora_size)
self.lora_model_config = lora_model_config
lora_target_modules = lora_model_config.lora_target_modules
@@ -74,6 +77,7 @@ def __init__(
max_rank=self.max_lora_rank,
layer_info=self.layer_info,
device=self.device,
max_tokens_per_seq=self.max_tokens_per_seq,
)

def _initialize_from_model(self, model: torch.nn.Module):
Expand Down Expand Up @@ -127,14 +131,16 @@ def prepare_cuda_graph_lora_params(
scheduled_requests: "ScheduledRequests",
attn_metadata: "AttentionMetadata",
peft_cache_manager: PeftCacheManager,
tokens_per_seq: int = 1,
) -> Optional[Dict]:
"""
Prepare LoRA parameters from scheduled requests.

Args:
scheduled_requests: The scheduled requests for the current batch
attn_metadata: Attention metadata containing batch information
peft_table: PEFT table from cache manager mapping task_id to layer-module-configs
peft_cache_manager: PEFT cache manager
tokens_per_seq: Number of tokens per sequence (for spec decode > 1)

Returns:
LoRA parameters dictionary.
@@ -151,7 +157,7 @@
request_slot_ids = self.adapter_slot_manager.update_slots(request_list, peft_cache_manager)

cuda_graph_lora_params = self.cuda_graph_lora_params
cuda_graph_lora_params.update_sorted_indices(request_slot_ids)
cuda_graph_lora_params.update_sorted_indices(request_slot_ids, tokens_per_seq)

# Get current slot to task mapping
slot2task = self.adapter_slot_manager.get_slot_to_task_mapping()
@@ -162,7 +168,9 @@
self.adapter_slot_manager.reset_slots_changed()

# Update GEMM sizes and prefix sums using batch
cuda_graph_lora_params.update_slots_params(batch_slot_ids=request_slot_ids)
cuda_graph_lora_params.update_slots_params(
batch_slot_ids=request_slot_ids, tokens_per_seq=tokens_per_seq
)

lora_params = {
"cuda_graph_params": cuda_graph_lora_params,
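A sketch of why the `tokens_per_seq` plumbing matters; the helper below is illustrative only and is not the manager's actual bookkeeping. With speculative decoding each scheduled sequence contributes `tokens_per_seq` rows to the batch, so anything sized per token (sorted indices, per-slot GEMM sizes and their prefix sums) has to scale by that factor:

```python
from collections import Counter
from typing import List

def per_slot_token_counts(batch_slot_ids: List[int], tokens_per_seq: int = 1) -> Counter:
    """Tokens routed to each adapter slot in one step: sequences per slot x tokens_per_seq."""
    counts = Counter(batch_slot_ids)
    return Counter({slot: n * tokens_per_seq for slot, n in counts.items()})

# Three sequences on slots [0, 1, 0]; with 4 tokens per sequence (e.g. spec-decode drafts),
# slot 0 now receives 8 token rows and slot 1 receives 4, instead of 2 and 1.
print(per_slot_token_counts([0, 1, 0], tokens_per_seq=4))  # Counter({0: 8, 1: 4})
```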