Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 19 additions & 26 deletions examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml
Original file line number Diff line number Diff line change
@@ -1,38 +1,31 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_seq_len: 4096
attn_backend: trtllm
max_seq_len: 8192
max_num_tokens: 4096
max_batch_size: 512
world_size: 2
world_size: 4
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
enable_chunked_prefill: true
model_factory: AutoModelForCausalLM
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.95
tokens_per_block: 64
free_gpu_memory_fraction: 0.8
tokens_per_block: 32
model_kwargs:
torch_dtype: bfloat16
# text_config:
# num_hidden_layers: 6
# vision_config:
# depth: 2
transforms:
export_to_gm:
num_moe_experts_for_export: 2
fuse_gemms_mixed_children:
enabled: true
detect_sharding:
sharding_dims: ['tp','ep', 'bmm']
# use only manual config for TP sharding
sharding_source: ['manual']
manual_config:
tp_plan:
# GDN layer
"in_proj_qkv": "delta"
# attention layer
"q_proj": "colwise"
"k_proj": "colwise"
"v_proj": "colwise"
"o_proj": "rowwise"
# replicating shared experts (keep them commented out)
# "shared_expert_gate_proj": "colwise"
# "shared_expert_up_proj": "colwise"
# "shared_expert_down_proj": "rowwise"
# gating layer should be replicated as well
# "gate": "gather"
allreduce_strategy: SYMM_MEM
multi_stream_moe:
stage: compile
enabled: true
multi_stream_gemm:
stage: compile
enabled: true
gather_logits_before_lm_head:
enabled: true
42 changes: 17 additions & 25 deletions examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_seq_len: 2048
max_num_tokens: 2048
max_batch_size: 512
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
attn_backend: trtllm
max_seq_len: 262144
max_num_tokens: 8192
max_batch_size: 32
cuda_graph_batch_sizes: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]
world_size: 8
enable_chunked_prefill: true
model_factory: AutoModelForCausalLM
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.95
tokens_per_block: 64
enable_block_reuse: true
free_gpu_memory_fraction: 0.8
tokens_per_block: 32
model_kwargs:
torch_dtype: bfloat16
transforms:
Expand All @@ -19,21 +20,12 @@ transforms:
fuse_gemms_mixed_children:
enabled: true
detect_sharding:
sharding_dims: ['tp','ep', 'bmm']
# use only manual config for TP sharding
sharding_source: ['manual']
manual_config:
tp_plan:
# GDN layer
"in_proj_qkv": "delta"
# attention layer
"q_proj": "colwise"
"k_proj": "colwise"
"v_proj": "colwise"
"o_proj": "rowwise"
# replicating shared experts (keep them commented out)
# "shared_expert_gate_proj": "colwise"
# "shared_expert_up_proj": "colwise"
# "shared_expert_down_proj": "rowwise"
# gating layer should be replicated as well
# "gate": "gather"
allreduce_strategy: SYMM_MEM
multi_stream_moe:
stage: compile
enabled: true
multi_stream_gemm:
stage: compile
enabled: true
gather_logits_before_lm_head:
enabled: true
18 changes: 16 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ transforms:
run_shape_prop: true
match_l2norm_pattern:
stage: pattern_matcher
match_moe_routing_pattern:
stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
Expand Down Expand Up @@ -89,15 +91,20 @@ transforms:
# proceeds normally.
match_swiglu_pattern:
stage: pattern_matcher
enabled: false
enabled: true
match_nvfp4_swiglu_pattern:
stage: pattern_matcher
requires_shape_prop: true
enabled: false
enabled: true
match_finegrained_fp8_swiglu_pattern:
stage: pattern_matcher
requires_shape_prop: true
enabled: true
quantize_fp8_moe:
stage: pattern_matcher
quantize_nvfp4_moe:
stage: pattern_matcher
run_shape_prop: true
quantize_mxfp4_moe:
stage: pattern_matcher
detect_hidden_states_for_capture:
Expand Down Expand Up @@ -156,6 +163,8 @@ transforms:
enabled: true
fuse_nvfp4_swiglu:
stage: post_load_fusion
fuse_finegrained_fp8_swiglu:
stage: post_load_fusion
fuse_finegrained_fp8_linear:
stage: post_load_fusion
backend: trtllm
Expand Down Expand Up @@ -185,6 +194,8 @@ transforms:
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true
fuse_gdn_gating:
stage: post_load_fusion
fuse_l2norm:
stage: post_load_fusion
backend: fla
Expand Down Expand Up @@ -249,6 +260,9 @@ transforms:
multi_stream_mla_attn:
stage: compile
enabled: false
multi_stream_gemm:
stage: compile
enabled: false
compile_model:
stage: compile
expect_mem_change: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def reset(self, device: torch.device) -> None:

# NOTE (lucaslie): avoid OOM for many cudagraphs,
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8)
self.workspace_buffer = torch.empty(1024 * 1024 * 1024, device=device, dtype=torch.uint8)

# NOTE (lucaslie): flashinfer fa3 backend has accuracy issue + illegal memory access issues
# on H100 PCIe, see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
Expand Down
15 changes: 15 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ class SequenceInfo:
Total sequence length including cached tokens for each sequence (input_pos + seq_len).
- use_initial_states: [bool_0, bool_1, ..., bool_{b-1}]
Per-sequence boolean indicating whether initial states should be used (True if input_pos > 0).
- any_prefill_use_initial_states: [bool]
Scalar boolean indicating whether any prefill sequence needs initial states. Precomputed on
the host to avoid GPU->CPU sync from torch.any() on the device tensor per layer.

### OTHER ARGUMENTS USED BY THE RUNTIME ########################################################
- extra_page_per_seq: [ep_0, ep_1, ..., ep_{b-1}]
Expand Down Expand Up @@ -527,6 +530,7 @@ def __init__(
("last_page_len", self.max_batch_size, torch.int),
("slot_idx", self.max_batch_size, torch.long),
### INFO OBJECTS THAT ARE AVAILABLE TO DESCRIBE THE INPUTS IN A MORE COMPACT WAY #######
("any_prefill_use_initial_states", 1, torch.bool),
("batch_info", 3, torch.int),
("max_seq_info", 4, torch.int),
### ADDITIONAL ARGUMENTS AVAILABLE THAT ARE DERIVED FROM THE BASIC ARGUMENTS ###########
Expand Down Expand Up @@ -1037,6 +1041,17 @@ def nest_sequences(
use_initial_states = ip_host > 0
self._stage_arg("use_initial_states", use_initial_states)

# precompute any(use_initial_states[:num_prefill]) on the host to avoid
# per-layer GPU->CPU sync from torch.any() inside cached ops
if self._is_required("any_prefill_use_initial_states"):
bi_host = self.get_arg("batch_info_host")
num_prefill = bi_host[0].item()
uis = self.get_arg("use_initial_states_host", truncate=True)
self._stage_arg(
"any_prefill_use_initial_states",
[bool(uis[:num_prefill].any())],
)

### UPDATE LOGITS GATHERING METADATA using heuristic if not provided #######################
# default is to gather all logits
if token_gather_indices is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def fla_cached_delta_rule(
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
any_prefill_use_initial_states_host: torch.Tensor,
# EXTRA METADATA
#
# CACHES
Expand Down Expand Up @@ -82,7 +83,8 @@ def fla_cached_delta_rule(

if num_prefill > 0:
initial_states = None
if torch.any(use_initial_states[:num_prefill]):
# Use precomputed host flag to avoid GPU->CPU sync from torch.any()
if any_prefill_use_initial_states_host.item():
initial_states = torch.where(
use_initial_states[:num_prefill, None, None, None],
delta_cache[slot_idx[:num_prefill]],
Expand Down Expand Up @@ -138,6 +140,7 @@ def fla_cached_delta_rule_fake(
cu_seqlen: torch.Tensor,
slot_idx: torch.Tensor,
use_initial_states: torch.Tensor,
any_prefill_use_initial_states_host: torch.Tensor,
# EXTRA METADATA
#
# CACHES
Expand Down Expand Up @@ -169,7 +172,13 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_standard_metadata_args(cls) -> List[str]:
return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]
return [
"batch_info_host",
"cu_seqlen",
"slot_idx",
"use_initial_states",
"any_prefill_use_initial_states_host",
]

@classmethod
def get_cache_initializers(
Expand Down
Loading
Loading