Skip to content

Commit 1cfa834

Browse files
authored
Add full model cuda graph support for MTP inference (NVIDIA#4950)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
1 parent b60de39 commit 1cfa834

7 files changed

Lines changed: 366 additions & 35 deletions

File tree

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636
from megatron.core.package_info import __version__ as mcore_version
3737
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
38+
from megatron.core.transformer.enums import InferenceCudaGraphScope
3839
from megatron.core.transformer.moe.token_dispatcher_inference import (
3940
InferenceAllGatherDispatcherBase,
4041
NCCLAllGatherDispatcher,
@@ -555,6 +556,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
555556

556557
# Initialize context state.
557558
self.params_dtype = model_config.params_dtype
559+
self.hidden_size = model_config.hidden_size
560+
self.inference_cuda_graph_scope = model_config.inference_cuda_graph_scope
558561
self.max_sequence_length = inference_config.max_sequence_length
559562

560563
# Block ids. With speculative decoding, blocks are pre-allocated when the
@@ -698,6 +701,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
698701
self.use_flashinfer_fused_rope = inference_config.use_flashinfer_fused_rope
699702
self.inference_grouped_gemm_backend = model_config.inference_grouped_gemm_backend
700703

704+
# Placeholder for the MTP decoder hidden-states buffer; allocated inside
705+
# initialize_all_tensors() when num_speculative_tokens > 0.
706+
self.mtp_decoder_hidden_states = None
707+
701708
# Allocate GPU state.
702709
self.is_tensor_state_allocated = False
703710
self.initialize_all_tensors()
@@ -1270,6 +1277,23 @@ def initialize_all_tensors(self) -> None:
12701277
):
12711278
self._allocate_mamba_cache(self.config.prefix_caching_mamba_gb)
12721279

1280+
# MTP speculative decoding: persistent buffer for decoder hidden states.
1281+
# Only needed for block-scope CUDA graphs, where the Python assignment in
1282+
# forward() runs only during graph capture. Using copy_() into a fixed
1283+
# buffer ensures every batch-size graph replay writes to the same GPU
1284+
# address. Sized to max_tokens; only [:actual_tokens] is valid each step.
1285+
if (
1286+
self.num_speculative_tokens > 0
1287+
and self.inference_cuda_graph_scope == InferenceCudaGraphScope.block
1288+
):
1289+
self.mtp_decoder_hidden_states = torch.empty(
1290+
self.max_tokens,
1291+
1,
1292+
self.hidden_size,
1293+
device=torch.cuda.current_device(),
1294+
dtype=self.params_dtype,
1295+
)
1296+
12731297
# Reset tensor-related metadata.
12741298
self.reset_metadata()
12751299

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
gather_from_sequence_parallel_region,
3737
scatter_to_sequence_parallel_region,
3838
)
39+
from megatron.core.transformer.enums import InferenceCudaGraphScope
3940
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
4041
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
4142
from megatron.core.transformer.utils import set_model_to_sequence_parallel
@@ -769,13 +770,11 @@ def _compute_serial_mtp_and_sample(self):
769770
unwrapped_model = self._unwrapped_model
770771

771772
# On non-last pipeline stages, the model won't have decoder hidden states.
772-
has_mtp = self._is_last_pp_stage and hasattr(
773-
unwrapped_model, '_decoder_hidden_states_cache'
774-
)
773+
has_mtp = self._is_last_pp_stage and context.mtp_decoder_hidden_states is not None
775774

776775
if has_mtp:
777776
# Get decoder hidden states at last accepted positions.
778-
hidden_states = unwrapped_model._decoder_hidden_states_cache
777+
hidden_states = context.mtp_decoder_hidden_states
779778

780779
# When SP is active the decoder output is in scattered format
781780
# [S/TP, B, H], but _last_accepted_seq_indices are indices into
@@ -889,9 +888,12 @@ def _compute_serial_mtp_and_sample(self):
889888
next_token_ids = spec_tokens
890889
nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}")
891890

892-
# Clean up cached hidden states.
893-
if has_mtp:
894-
del unwrapped_model._decoder_hidden_states_cache
891+
# In eager mode forward() assigns the hidden states tensor directly to
892+
# the context attribute; release it so the tensor can be garbage
893+
# collected. In block-scope CUDA graph mode the attribute is a
894+
# pre-allocated fixed buffer that must persist across replays.
895+
if has_mtp and context.inference_cuda_graph_scope != InferenceCudaGraphScope.block:
896+
context.mtp_decoder_hidden_states = None
895897

896898
def _verify_speculative_tokens(
897899
self,
@@ -1517,15 +1519,13 @@ def _dummy_serial_mtp_forward(self):
15171519
if self.model_config.expert_model_parallel_size <= 1:
15181520
return
15191521

1520-
unwrapped_model = self._unwrapped_model
1521-
1522-
has_mtp = self._is_last_pp_stage and hasattr(
1523-
unwrapped_model, '_decoder_hidden_states_cache'
1524-
)
1522+
context = self.inference_wrapped_model.inference_context
1523+
has_mtp = self._is_last_pp_stage and context.mtp_decoder_hidden_states is not None
15251524
if not has_mtp and not self.model_is_pipeline_parallel:
15261525
# No MTP on this rank and no PP broadcast to participate in.
15271526
return
15281527

1528+
unwrapped_model = self._unwrapped_model
15291529
device = torch.cuda.current_device()
15301530
dtype = self.model_config.params_dtype
15311531
hidden_size = self.model_config.hidden_size

megatron/core/models/gpt/gpt_model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from megatron.core.process_groups_config import ProcessGroupCollection
2828
from megatron.core.quantization.utils import get_quant_config_or_none
2929
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
30-
from megatron.core.transformer.enums import ModelType
30+
from megatron.core.transformer.enums import InferenceCudaGraphScope, ModelType
3131
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
3232
from megatron.core.transformer.multi_token_prediction import (
3333
MultiTokenPredictionBlock,
@@ -668,7 +668,14 @@ def _postprocess(
668668
if in_inference_mode or is_spec_decode:
669669
# Cache decoder hidden states for serial MTP computation
670670
# after speculative token verification.
671-
self._decoder_hidden_states_cache = hidden_states
671+
if inference_context is not None:
672+
if self.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block:
673+
assert inference_context.mtp_decoder_hidden_states is not None
674+
inference_context.mtp_decoder_hidden_states[: hidden_states.shape[0]].copy_(
675+
hidden_states
676+
)
677+
else:
678+
inference_context.mtp_decoder_hidden_states = hidden_states
672679
else:
673680
# In training/eval, use the utility function for processing MTP loss/scaling.
674681
hidden_states = process_mtp_loss(

megatron/core/models/hybrid/hybrid_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,19 @@ def forward(
544544
if self.config.mtp_num_layers is not None and self.mtp_process:
545545
assert self.config.mtp_num_layers > 0
546546
if in_inference_mode or is_spec_decode:
547-
self._decoder_hidden_states_cache = hidden_states
547+
if inference_context is not None:
548+
if self.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block:
549+
# Block-scope CUDA graph mode: copy_() into the
550+
# pre-allocated buffer so every graph replay writes to
551+
# the same fixed GPU address regardless of batch size.
552+
assert inference_context.mtp_decoder_hidden_states is not None
553+
inference_context.mtp_decoder_hidden_states[: hidden_states.shape[0]].copy_(
554+
hidden_states
555+
)
556+
else:
557+
# Non-block scope: direct assignment; the controller will set
558+
# this back to None after reading to allow GC.
559+
inference_context.mtp_decoder_hidden_states = hidden_states
548560
else:
549561
# For RL (labels is None), process_mtp_loss derives labels from
550562
# input_ids to match the SFT label format.

tests/unit_tests/inference/engines/test_dynamic_engine.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,7 +2581,7 @@ def mock_mtp_forward(*args, **kwargs):
25812581
base_logits[:, :, 0] = 100.0 # High probability for token 0
25822582

25832583
# Cache hidden states for serial MTP computation
2584-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
2584+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
25852585
tokens.size(1), 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
25862586
)
25872587
if test_config.materialize_only_last_token_logits:
@@ -2720,7 +2720,7 @@ def mock_deterministic_forward(*args, **kwargs):
27202720
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
27212721

27222722
# Cache hidden states for serial MTP computation
2723-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
2723+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
27242724
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
27252725
)
27262726
if test_config.materialize_only_last_token_logits:
@@ -2815,7 +2815,7 @@ def mock_deterministic_forward(*args, **kwargs):
28152815
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
28162816

28172817
# Cache hidden states for serial MTP computation
2818-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
2818+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
28192819
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
28202820
)
28212821
if test_config.materialize_only_last_token_logits:
@@ -2911,7 +2911,7 @@ def mock_deterministic_forward(*args, **kwargs):
29112911
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
29122912

29132913
# Cache hidden states for serial MTP computation
2914-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
2914+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
29152915
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
29162916
)
29172917
if test_config.materialize_only_last_token_logits:
@@ -3187,7 +3187,7 @@ def mock_mtp_forward(*args, **kwargs):
31873187
next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1)
31883188
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
31893189

3190-
model._decoder_hidden_states_cache = torch.zeros(
3190+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
31913191
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
31923192
)
31933193
if test_config.materialize_only_last_token_logits:
@@ -3308,7 +3308,7 @@ def mock_safe_forward(*args, **kwargs):
33083308
base_logits[:, :, 0] = 100.0 # Force model to deterministically pick token 0
33093309

33103310
# Cache hidden states for serial MTP computation
3311-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
3311+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
33123312
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
33133313
)
33143314
if test_config.materialize_only_last_token_logits:
@@ -3526,7 +3526,7 @@ def mock_mtp_forward(*args, **kwargs):
35263526
dtype=torch.bfloat16,
35273527
)
35283528
base_logits[:, :, 0] = 100.0
3529-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
3529+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
35303530
tokens.size(1), 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
35313531
)
35323532
return base_logits
@@ -3669,7 +3669,7 @@ def mock_deterministic_forward(*args, **kwargs):
36693669
)
36703670
# Make token 0 very likely so speculative tokens get accepted.
36713671
base_logits[:, :, 0] = 100.0
3672-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
3672+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
36733673
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
36743674
)
36753675
return base_logits
@@ -3791,7 +3791,7 @@ def mock_deterministic_forward(*args, **kwargs):
37913791
b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16
37923792
)
37933793
base_logits[:, :, 0] = 100.0
3794-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
3794+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
37953795
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
37963796
)
37973797
return base_logits
@@ -3923,7 +3923,7 @@ def mock_deterministic_forward(*args, **kwargs):
39233923
b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16
39243924
)
39253925
base_logits[:, :, 0] = 100.0
3926-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
3926+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
39273927
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
39283928
)
39293929
return base_logits
@@ -4178,7 +4178,7 @@ def mock_deterministic_forward(*args, **kwargs):
41784178
)
41794179
next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1)
41804180
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
4181-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
4181+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
41824182
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
41834183
)
41844184
return base_logits
@@ -4276,7 +4276,7 @@ def mock_deterministic_forward(*args, **kwargs):
42764276
)
42774277
next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1)
42784278
base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0)
4279-
unwrapped_model._decoder_hidden_states_cache = torch.zeros(
4279+
env.engine.context.mtp_decoder_hidden_states = torch.zeros(
42804280
s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16
42814281
)
42824282
return base_logits

0 commit comments

Comments
 (0)