3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -280,7 +280,8 @@ def update_cuda_graph_batch_sizes(self):
         # if not set, use heuristic
         if self.cuda_graph_batch_sizes is None:
             cg_bs = {1, self.max_batch_size}
-            cg_bs.update(range(1, 128 + 1, 16))
+            # Only add batch sizes up to max_batch_size
+            cg_bs.update(range(1, min(128, self.max_batch_size) + 1, 16))
             cg_bs.update(range(128, self.max_batch_size + 1, 128))
         else:
             cg_bs = [b for b in self.cuda_graph_batch_sizes if b <= self.max_batch_size]
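Worth spelling out what this one-line fix changes: the old heuristic unconditionally added batch sizes up to 128, so for a small max_batch_size CUDA graphs were captured for batch sizes that can never occur. A standalone sketch of the heuristic before and after the cap (the function name and the capped flag are illustrative, not from the codebase):

def cg_batch_sizes(max_batch_size: int, capped: bool) -> list[int]:
    """Mirror of the heuristic branch above; `capped` toggles the fix."""
    cg_bs = {1, max_batch_size}
    upper = min(128, max_batch_size) if capped else 128
    cg_bs.update(range(1, upper + 1, 16))
    cg_bs.update(range(128, max_batch_size + 1, 128))
    return sorted(cg_bs)

print(cg_batch_sizes(8, capped=False))  # [1, 8, 17, 33, 49, 65, 81, 97, 113]
print(cg_batch_sizes(8, capped=True))   # [1, 8]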
(file path not shown)
@@ -110,10 +110,6 @@ def _apply(
         original_input.replace_all_uses_with(new_contiguous_node)
         new_contiguous_node.replace_input_with(new_contiguous_node, original_input)

-        # Clean up the graph
-        if nodes_to_eliminate:
-            gm.graph.eliminate_dead_code()
-
         info = TransformInfo(
             skipped=False,
             num_matches=len(nodes_to_eliminate),
(file path not shown)
@@ -127,12 +127,10 @@ def _apply(
             graph.erase_node(activation_node)
             graph.erase_node(conv_node)

-        gm.recompile()
-
         info = TransformInfo(
             skipped=False,
             num_matches=len(matches),
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=len(matches) == 0,
+            has_valid_shapes=len(matches) == 0,
         )
         return gm, info
(file path not shown)
@@ -213,5 +213,5 @@ def _apply(
         skipped=False,
         num_matches=num_matches,
         is_clean=num_matches == 0,
-        has_valid_shapes=True,
+        has_valid_shapes=num_matches == 0,
     )
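A note on the recurring TransformInfo change in these transforms: is_clean and has_valid_shapes now report num_matches == 0 (or len(matches) == 0) instead of a hard-coded value, so a pass that matched nothing tells the optimizer the graph was untouched and the follow-up cleanup and shape re-propagation can be skipped. A minimal sketch of the pattern, assuming TransformInfo behaves like a plain dataclass (field names taken from the diff; the real class lives in ..interface):

from dataclasses import dataclass

@dataclass
class TransformInfo:
    skipped: bool
    num_matches: int
    is_clean: bool          # graph untouched -> dead-code elimination can be skipped
    has_valid_shapes: bool  # shape metadata intact -> shape propagation can be skipped

def report(num_matches: int) -> TransformInfo:
    # A no-op pass (num_matches == 0) leaves the graph clean and its shapes valid;
    # any real match dirties both flags so canonicalization runs afterwards.
    return TransformInfo(
        skipped=False,
        num_matches=num_matches,
        is_clean=num_matches == 0,
        has_valid_shapes=num_matches == 0,
    )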
21 changes: 11 additions & 10 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
@@ -10,6 +10,7 @@

 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
+from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
 from ..interface import (
@@ -112,8 +113,8 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
     # Delete the unstacked weights immediately to save GPU memory
     # This will happen automatically after the graph is canonicalized,
     # but for large models we'll run out of memory during the transformation itself.
-    gm.graph.eliminate_dead_code()
-    gm.delete_all_unused_submodules()
+    eliminate_dead_code(gm)
+    delete_all_unused_submodules(gm)

     return fused_key_counter

@@ -635,7 +636,7 @@ def _apply(
             graph.erase_node(final_hidden_state_node)

         while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
-            gm.graph.eliminate_dead_code()
+            eliminate_dead_code(gm)

         num_moe_patterns += 1

@@ -1272,14 +1273,14 @@ def _apply(
         graph.erase_node(output_node)

         # Clean up dead nodes
-        gm.graph.eliminate_dead_code()
+        eliminate_dead_code(gm)

         # Clean up dead inplace nodes in the region
         while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
-            gm.graph.eliminate_dead_code()
+            eliminate_dead_code(gm)

         # Delete unused submodules/parameters
-        gm.delete_all_unused_submodules()
+        delete_all_unused_submodules(gm)

         num_moe_patterns += 1

@@ -1517,8 +1518,8 @@ def _prepare_args_triton_format():
     # Clean up after processing all nodes
     # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
     # will remove the parameters/buffers that are no longer referenced
-    gm.graph.eliminate_dead_code()
-    gm.delete_all_unused_submodules()
+    eliminate_dead_code(gm)
+    delete_all_unused_submodules(gm)

     return fused_key_counter

@@ -1776,8 +1777,8 @@ def _prepare_args_cutlass_format_nvfp4():
     # Clean up after processing all nodes
     # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
     # will remove the parameters/buffers that are no longer referenced
-    gm.graph.eliminate_dead_code()
-    gm.delete_all_unused_submodules()
+    eliminate_dead_code(gm)
+    delete_all_unused_submodules(gm)
     return fused_key_counter


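The comment block in these hunks explains why the cleanup runs eagerly instead of waiting for canonicalization: dead-code elimination removes the get_attr nodes for the now-unstacked expert weights, and only then can submodule deletion release the parameters and their GPU memory. A self-contained toy illustration of that two-step ordering using the stock torch.fx methods (same semantics as the new utils._graph wrappers; the Toy module is illustrative):

import torch.nn as nn
from torch.fx import symbolic_trace

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x) + x

gm = symbolic_trace(Toy())

# Pretend a fusion pass replaced the call to `proj` with a fused op,
# leaving the call_module node (and the Linear's weights) unreferenced.
for node in list(gm.graph.nodes):
    if node.op == "call_module" and node.target == "proj":
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()     # 1) drop dead nodes referencing the module
gm.delete_all_unused_submodules()  # 2) now the Linear and its weights are freed
gm.recompile()
assert not hasattr(gm, "proj")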
9 changes: 5 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
@@ -11,6 +11,7 @@

 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
+from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.logger import ad_logger
 from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
@@ -75,8 +76,8 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         n.replace_all_uses_with(get_split_node)

     # Clean up deleted modules to save GPU memory
-    gm.graph.eliminate_dead_code()
-    gm.delete_all_unused_submodules()
+    eliminate_dead_code(gm)
+    delete_all_unused_submodules(gm)


 def check_same_children(parent_node: Node, is_desired_child: Callable[[Node], bool]) -> bool:
@@ -185,8 +186,8 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
             n.replace_all_uses_with(get_split_node)

         # Clean up deleted modules to save GPU memory
-        gm.graph.eliminate_dead_code()
-        gm.delete_all_unused_submodules()
+        eliminate_dead_code(gm)
+        delete_all_unused_submodules(gm)

     def _apply_fusion_pass(
         self,
(file path not shown)
@@ -291,10 +291,6 @@ def _execute_op_in_aux_stream(
             n.replace_all_uses_with(new_node)
             graph.erase_node(n)
             num_replaced += 1
-    if num_replaced:
-        graph.eliminate_dead_code()
-        graph.lint()
-        gm.recompile()

     return gm, num_replaced
@@ -322,8 +318,8 @@ def _apply(
         info = TransformInfo(
             skipped=False,
             num_matches=num_matches,
-            is_clean=False,
-            has_valid_shapes=False,
+            is_clean=num_matches == 0,
+            has_valid_shapes=num_matches == 0,
         )

         return gm, info
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -33,7 +33,7 @@
 from ...custom_ops.trtllm_dist import is_trtllm_op_available
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
-from ...utils._graph import del_attr_by_name
+from ...utils._graph import del_attr_by_name, eliminate_dead_code
 from ...utils.logger import ad_logger
 from ...utils.node_utils import (
     LayerSubgraph,
@@ -1444,7 +1444,7 @@ def get_partition(lst, world_size, rank):
     node.replace_all_uses_with(dist_node)
     dist_node.replace_input_with(dist_node, node)

-    gm.graph.eliminate_dead_code()
+    eliminate_dead_code(gm)
     # Expert weights registered via gm.register_parameter() are top-level attributes.
     # Unlike submodules, these aren't cleaned up by eliminate_dead_code() or
     # delete_all_unused_submodules() - must delete manually after removing their get_attr nodes.
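The comment in this hunk captures an easy-to-miss FX gotcha: eliminate_dead_code only prunes graph nodes, and delete_all_unused_submodules only walks named_modules(), so a parameter registered directly on the root GraphModule survives both and keeps holding GPU memory. A small sketch of the failure mode and the manual fix (a toy module; delattr stands in for del_attr_by_name, which I assume additionally resolves dotted attribute paths):

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4))

    def forward(self, x):
        return x + self.w

gm = symbolic_trace(Toy())

# Drop the get_attr node for `w`, as a sharding pass would after swapping
# in a sharded replacement for the weight.
x_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
for node in list(gm.graph.nodes):
    if node.op == "get_attr" and node.target == "w":
        node.replace_all_uses_with(x_node)
        gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
gm.recompile()

assert "w" in dict(gm.named_parameters())  # still alive and holding memory
delattr(gm, "w")                           # the manual cleanup step
assert "w" not in dict(gm.named_parameters())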
90 changes: 82 additions & 8 deletions tensorrt_llm/_torch/auto_deploy/utils/_graph.py
@@ -1,7 +1,8 @@
"""Graph-related utilities for transformations."""

import itertools
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
@@ -129,8 +130,8 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
     )
     if recompile_graph:
         # recompile graph to update self generated codes in subgraph
-        gm.graph.lint()
-        gm.recompile()
+        lint(gm)
+        recompile(gm)


 def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None:
@@ -161,18 +162,91 @@ def _is_impure_node(node: Node) -> bool:
     node.target._nondeterministic_seeded = True


+def delete_all_unused_submodules(gm: GraphModule) -> None:
+    """Optimized version of delete_all_unused_submodules with O(n+m) complexity.
+
+    The original implementation uses a list for tracking used modules, making membership
+    checks O(n). This version uses a set for O(1) lookups.
+
+    The original implementation is GraphModule.delete_all_unused_submodules.
+
+    Args:
+        gm: The GraphModule to clean up.
+    """
+    used: Set[str] = set()
+
+    for node in itertools.chain(
+        gm.graph.find_nodes(op="call_module", sort=False),
+        gm.graph.find_nodes(op="get_attr", sort=False),
+    ):
+        # Skip targets that are already marked as used -- unless this is a call_module
+        # node, in which case we still need to mark all recursive submodules as used.
+        if node.target in used and node.op != "call_module":
+            continue
+
+        # A list of strings representing the different parts of the path.
+        # For example, `foo.bar.baz` gives us ["foo", "bar", "baz"].
+        fullpath = node.target.split(".")
+
+        # Progressively collect the names of all intermediate modules. For example,
+        # the target `foo.bar.baz` adds `foo`, `foo.bar`, and `foo.bar.baz`.
+        used.update(".".join(fullpath[:i]) for i in range(1, len(fullpath) + 1))
+
+        # For call_module, also mark all recursive submodules as used
+        if node.op == "call_module":
+            try:
+                submod = gm.get_submodule(node.target)
+                for submod_name, _ in submod.named_modules():
+                    if submod_name != "":
+                        used.add(f"{node.target}.{submod_name}")
+            except AttributeError:
+                # Node referenced a nonexistent submodule; nothing to GC here.
+                pass
+
+    # also add the root module to the used set
+    used.add("")
+
+    # Go over all modules and delete those that were never marked as used. Since we use
+    # named_modules, parents are deleted first and their children are automatically
+    # skipped inside delete_submodule.
+    to_delete = [name for name, _ in gm.named_modules() if name not in used]
+    for name in to_delete:
+        gm.delete_submodule(name)
+
+
+def eliminate_dead_code(
+    gm: GraphModule, is_impure_node: Optional[Callable[[Node], bool]] = None
+) -> None:
+    """Eliminate dead code from the graph of the given GraphModule."""
+    gm.graph.eliminate_dead_code(is_impure_node=is_impure_node)
+
+
+def recompile(gm: GraphModule) -> None:
+    """Recompile the graph of the given GraphModule."""
+    gm.recompile()
+
+
+def lint(gm: GraphModule) -> None:
+    """Lint the graph of the given GraphModule."""
+    gm.graph.lint()
+
+
 def _canonicalize_single_gm(gm: GraphModule) -> None:
     # clean up graph (needs to be done repeatedly until no more dead code)
-    gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
+    eliminate_dead_code(gm, is_impure_node=_is_impure_node)

     # recompile to propagate all graph changes to the graph module
-    gm.recompile()
+    recompile(gm)

     # clean up graph module
-    gm.delete_all_unused_submodules()
+    delete_all_unused_submodules(gm)

     # lint the graph
-    gm.graph.lint()
+    lint(gm)


 def canonicalize_graph(mod: nn.Module) -> None:
@@ -217,7 +291,7 @@ def _run_shape_prop_single_gm(
         ad_logger.warning("No fake tensors and no args available for shape propagation")

     # lint the graph
-    gm.graph.lint()
+    lint(gm)


 def run_shape_prop(
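With these helpers in place, call sites no longer mix gm.graph.* and gm.* methods; the canonicalization sequence reads as four module-level calls. A short usage sketch against the new wrappers (toy module; import path as added in this PR):

import torch.nn as nn
from torch.fx import symbolic_trace

from tensorrt_llm._torch.auto_deploy.utils._graph import (
    delete_all_unused_submodules,
    eliminate_dead_code,
    lint,
    recompile,
)

gm = symbolic_trace(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))

# Same order _canonicalize_single_gm uses internally:
eliminate_dead_code(gm)           # prune dead nodes (optional is_impure_node hook)
recompile(gm)                     # regenerate gm.forward from the mutated graph
delete_all_unused_submodules(gm)  # set-based O(n+m) submodule cleanup
lint(gm)                          # sanity-check graph invariants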
2 changes: 1 addition & 1 deletion tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -44,7 +44,7 @@ def get_default_kwargs(self, enable_chunked_prefill=False):
             },
             "compile_model": {
                 "backend":
-                    "torch-opt",
+                    "torch-cudagraph",
                 "cuda_graph_batch_sizes":
                     [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
             },
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -227,7 +227,6 @@ full:H100_PCIe/unittest/llmapi/test_llm_pytorch.py::test_llama_7b_multi_lora_evi
 unittest/_torch/speculative/test_draft_len_schedule.py::test_correctness_across_batch_sizes[model_drafter-schedule1] SKIP (https://nvbugs/5680911)
 accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype SKIP (https://nvbugs/5612438)
 accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True] SKIP (https://nvbugs/5688721)
-accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4] SKIP (https://nvbugs/5769712)
 test_e2e.py::test_openai_completions_example[trt] SKIP (https://nvbugs/5701450)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5701457)
 triton_server/test_triton_llm.py::test_llmapi_backend[4-0-disableDecoupleMode-tensorrt_llm] SKIP (https://nvbugs/5701480)
(file path not shown)
@@ -19,6 +19,7 @@
     SplitDimension,
     WeightShardingInfo,
 )
+from tensorrt_llm._torch.auto_deploy.utils._graph import recompile
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
 from tensorrt_llm.commands.bench import main
 from tensorrt_llm.functional import AllReduceStrategy
@@ -378,7 +379,7 @@ def forward(self, x):
     if node:
         transform.check_and_apply(gm, node)

-    gm.recompile()
+    recompile(gm)

     # Verify the graph contains torch_dist_all_reduce nodes with correct strategy
     allreduce_nodes = [
(file path not shown)
@@ -18,6 +18,7 @@
     ShardingTransformConfig,
 )
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
+from tensorrt_llm._torch.auto_deploy.utils._graph import lint, recompile
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
 from tensorrt_llm.functional import AllReduceStrategy

@@ -188,8 +189,8 @@ def test_llama4_stacked_moe_pattern_detection():
     )
     graph.output(moe_node)

-    graph.lint()
-    gm.recompile()
+    lint(gm)
+    recompile(gm)

     # Run pattern detection for EP
     optimizer = InferenceOptimizer(
(file path not shown)
@@ -9,6 +9,7 @@
     cuda_stream_manager,
     record_event_wrapper,
 )
+from tensorrt_llm._torch.auto_deploy.utils._graph import canonicalize_graph
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


@@ -75,9 +76,7 @@ def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tupl
         num_replaced += 1

     if num_replaced:
-        graph.eliminate_dead_code()
-        graph.lint()
-        gm.recompile()
+        canonicalize_graph(gm)

     return gm, num_replaced
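This last hunk is the consumer-side payoff of the new utilities: the ad-hoc eliminate_dead_code / lint / recompile triple at call sites collapses into a single canonicalize_graph(gm), the same sequence _canonicalize_single_gm runs in utils/_graph.py.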