Skip to content

Commit 72dc7ec

Browse files
authored
[None][fix] Fix multi_stream_moe accuracy with MLIR and piecewise cudagraphs (NVIDIA#12847)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
1 parent c4889d9 commit 72dc7ec

File tree

7 files changed

+449
-9
lines changed

7 files changed

+449
-9
lines changed

examples/auto_deploy/model_registry/configs/gemma4_moe.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ transforms:
2626
enabled: true
2727
fuse_gemms:
2828
enabled: true
29+
multi_stream_moe:
30+
enabled: true

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,22 @@ def prepare(self) -> None:
345345
gm = GraphModule(model, copy.deepcopy(model.graph))
346346

347347
self.split_info = split_graph_at_dynamic_ops(gm)
348+
349+
# When multi-stream transforms reclassify ALL static partitions as
350+
# dynamic (e.g. multi_stream_moe + multi_stream_mla_attn on every
351+
# layer), there are zero capturable static segments. Piecewise CUDA
352+
# graphs are impossible — fall back to eager execution for
353+
# prefill/mixed batches (monolithic CG still handles decode).
354+
if not self.split_info.static_submod_indices:
355+
ad_logger.warning(
356+
"PiecewiseCapturedGraph: no static partitions after splitting "
357+
"(%d dynamic). Piecewise CUDA graphs disabled — prefill/mixed "
358+
"batches will run eagerly.",
359+
len(self.split_info.dynamic_submod_indices),
360+
)
361+
self._is_prepared = True
362+
return
363+
348364
self.split_gm = self.split_info.split_gm
349365

350366
graph_pool = torch.cuda.graph_pool_handle()
@@ -408,6 +424,17 @@ def prepare(self) -> None:
408424
self.split_info.static_submod_indices + self.split_info.dynamic_submod_indices
409425
)
410426
current_static_runner: Optional[ADPiecewiseRunner] = None
427+
# Fallback runner: the first available static runner. When
428+
# multi-stream transforms reclassify the initial static partition(s)
429+
# as dynamic (e.g. record_event_passthrough from multi_stream_mla_attn)
430+
# AND the static partitions between metadata-prep and attention have
431+
# no CUDA ops (skipped), there is no *preceding* static runner for the
432+
# first attention op. In that case we fall back to the nearest
433+
# *following* static runner — any runner in the shared graph pool can
434+
# host the pre-allocated output buffer.
435+
fallback_runner: Optional[ADPiecewiseRunner] = None
436+
if runner_by_idx:
437+
fallback_runner = runner_by_idx[min(runner_by_idx)]
411438
num_metadata_wrapped = 0
412439
for idx in all_submod_indices:
413440
if idx in runner_by_idx:
@@ -430,15 +457,23 @@ def prepare(self) -> None:
430457
)
431458
continue
432459

433-
assert current_static_runner is not None, (
434-
f"Dynamic {submod_name} has no preceding static runner — "
460+
effective_runner = current_static_runner or fallback_runner
461+
assert effective_runner is not None, (
462+
f"Dynamic {submod_name} has no static runner available — "
435463
f"cannot allocate out= buffer for stable output addresses"
436464
)
465+
if current_static_runner is None:
466+
ad_logger.info(
467+
"PiecewiseCapturedGraph: %s has no preceding static "
468+
"runner, using fallback runner (submod_%d)",
469+
submod_name,
470+
min(runner_by_idx),
471+
)
437472

438473
_inject_out_param(submod)
439474
wrapper = DynamicOpWrapper(
440475
submod,
441-
preceding_runner=current_static_runner,
476+
preceding_runner=effective_runner,
442477
dynamic_submod_id=idx,
443478
)
444479
setattr(self.split_gm, submod_name, wrapper)

tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,20 @@
8686
"auto_deploy::cuda_cached_causal_conv1d",
8787
]
8888

89+
# Multi-stream passthrough functions that switch the CUDA current stream.
90+
# Static partitions containing these functions cannot be captured as CUDA
91+
# graphs because the host-side stream synchronization required for
92+
correctness (caller_stream.synchronize()) is not capturable. Such
93+
# partitions are reclassified as dynamic so they run eagerly.
94+
_STREAM_SWITCH_FUNCTION_NAMES = frozenset(
95+
{
96+
"begin_aux_stream_passthrough",
97+
"end_aux_stream_passthrough",
98+
"wait_aux_stream_passthrough",
99+
"record_event_passthrough",
100+
}
101+
)
102+
89103

90104
def _get_all_dynamic_op_names() -> Set[str]:
91105
"""Return the full set of dynamic op qualified names."""
@@ -193,11 +207,17 @@ def needs_out_buffer(submod: nn.Module) -> bool:
193207
194208
Inplace ops (mutate input, return None) don't produce new tensors.
195209
Metadata prep ops are handled by MetadataWrapper (stable output addresses).
196-
Both are skipped — only attention/SSM/delta/logits ops need out= buffers.
210+
Multi-stream partitions reclassified as dynamic run eagerly and manage
211+
their own output tensors — they do not need out= buffers.
212+
All of these are skipped — only attention/SSM/delta/logits ops need out= buffers.
197213
"""
198214
if not isinstance(submod, GraphModule):
199215
return True
200216

217+
# Multi-stream partitions (reclassified from static) do not need out= buffers.
218+
if _submod_has_stream_switch(submod):
219+
return False
220+
201221
for node in submod.graph.nodes:
202222
if node.op == "call_function" and is_dynamic_cached_op(node):
203223
op_name = node.target.name() if hasattr(node.target, "name") else str(node.target)
@@ -232,6 +252,16 @@ def is_metadata_prep(submod: nn.Module) -> bool:
232252
# ---------------------------------------------------------------------------
233253

234254

255+
def _submod_has_stream_switch(submod: GraphModule) -> bool:
256+
"""Return True if *submod* contains a multi-stream passthrough function."""
257+
for node in submod.graph.nodes:
258+
if node.op == "call_function":
259+
func_name = getattr(node.target, "__name__", "")
260+
if func_name in _STREAM_SWITCH_FUNCTION_NAMES:
261+
return True
262+
return False
263+
264+
235265
@dataclass
236266
class SplitInfo:
237267
"""Metadata about a split GraphModule."""
@@ -318,6 +348,26 @@ def partition_fn(node: Node) -> int:
318348

319349
submod_names.sort(key=lambda n: int(n.split("_")[1]))
320350

351+
# Reclassify static partitions that contain multi-stream passthrough
352+
# functions as dynamic. These partitions switch the CUDA current stream
353+
# at runtime, which requires a host-side caller_stream.synchronize() for
354+
# correctness with MLIR-fused Triton kernels. Since synchronize() cannot
355+
# be called during CUDA graph capture, such partitions must run eagerly.
356+
num_reclassified = 0
357+
for name in submod_names:
358+
pid = int(name.split("_")[1])
359+
if pid in dynamic_partitions:
360+
continue
361+
submod = getattr(split_gm, name)
362+
if isinstance(submod, GraphModule) and _submod_has_stream_switch(submod):
363+
dynamic_partitions.add(pid)
364+
num_reclassified += 1
365+
if num_reclassified:
366+
ad_logger.info(
367+
f"Piecewise split: reclassified {num_reclassified} static partition(s) "
368+
"as dynamic (contain multi-stream passthrough ops)"
369+
)
370+
321371
dynamic_indices = []
322372
static_indices = []
323373
for name in submod_names:

tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,21 @@ def _apply(
6262
node_to_gather = lm_head_node.all_input_nodes[0]
6363
self._log_info(f"Found LM head node: {lm_head_node.name}")
6464
else:
65-
# Walk backward through elementwise/unary ops (e.g. softcapping: div, tanh, mul)
66-
# to find the actual lm_head linear node.
65+
# Walk backward through SINGLE-INPUT elementwise/unary ops
66+
# (e.g. Gemma4 softcapping: linear → div → tanh → mul) to find the
67+
# actual lm_head linear node. Only follow nodes that have exactly
68+
# one tensor input to avoid branching into the model body (e.g.
69+
# residual adds, fused allreduce+norm ops).
6770
current = lm_head_node
6871
while current is not None and not is_linear_op(current):
69-
inputs = current.all_input_nodes
70-
current = inputs[0] if len(inputs) >= 1 else None
72+
tensor_inputs = [n for n in current.all_input_nodes if n.op != "get_attr"]
73+
if len(tensor_inputs) != 1:
74+
# Multi-input or no-input node — stop walking; the lm_head
75+
# is not in this graph (common for VLMs where only the text
76+
# backbone is exported and the lm_head is applied externally).
77+
current = None
78+
break
79+
current = tensor_inputs[0]
7180

7281
if current is not None and is_linear_op(current):
7382
node_to_gather = current.all_input_nodes[0]
@@ -76,7 +85,10 @@ def _apply(
7685
)
7786
else:
7887
node_to_gather = lm_head_node
79-
self._log_info("lm_head node is not linear, using it as the node to gather")
88+
self._log_info(
89+
f"lm_head linear not in graph; inserting gather before "
90+
f"output node ({lm_head_node.name})"
91+
)
8092

8193
# Add logits_gather_mask as input in the graph and the sequence info interface
8294
logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "token_gather_indices")

tensorrt_llm/_torch/auto_deploy/utils/multi_stream_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,19 @@ def begin_aux_stream_passthrough(
156156
# which is NOT ``torch.cuda.default_stream()``.
157157
caller_stream = torch.cuda.current_stream(device)
158158
cuda_stream_manager._caller_streams[device] = caller_stream
159+
# Synchronize the caller stream before switching to aux. The GPU-side
160+
# event wait (aux_stream.wait_event) alone is NOT sufficient when
161+
# MLIR-generated Triton kernels precede this point: their interaction
162+
# with PyTorch's CUDA caching allocator can cause the allocator to
163+
# recycle memory that the aux stream still needs, leading to illegal
164+
# memory accesses or silent data corruption. A CPU-side synchronize
165+
# ensures all caller-stream GPU work has retired before aux-stream
166+
# allocations begin.
167+
# NOTE: this cannot be called during CUDA graph capture. The cudagraph
168+
# path must rely on event-based sync only; a separate fix is needed
169+
# there (see TRTLLM multi_stream_moe + MLIR tracking).
170+
if not torch.cuda.is_current_stream_capturing():
171+
caller_stream.synchronize()
159172
# Record where the caller's stream has reached so aux knows when data is ready.
160173
main_event = cuda_stream_manager.get_event(device, cuda_stream_manager.MAIN_STREAM_NAME)
161174
main_event.record(caller_stream)

0 commit comments

Comments
 (0)