
Commit 8baf014

jiemingz and buptzyb authored

Various CUDA graph improvements on capture time, replay time, memory footprint (NVIDIA#2572)

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Jieming Zhang <jiemingz@nvidia.com>
Co-authored-by: Robin Zhang <robinz@nvidia.com>
1 parent dd72aee commit 8baf014

File tree

13 files changed, +1125 −637 lines

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 3 additions & 1 deletion
@@ -545,7 +545,9 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor):
         """
         residual = node.layer_state.residual
         shared_expert_output = getattr(node.layer_state, 'shared_expert_output', None)
-        output = layer.mlp.combine(output, shared_expert_output)
+        output = layer.mlp.combine(output)
+        output = layer.mlp.postprocess(output, shared_expert_output)
+
         mlp_output_with_bias = (output, None)
         if hasattr(layer, 'cuda_graphs') and layer.cuda_graphs:
             layer.mlp.cudagraph_tensor_store.clear()
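
Note: the change above splits the old single combine(output, shared_expert_output) call into a token-combine step and a separate postprocess step that folds in the shared-expert output, so the two steps can be handled independently (for example, placed in different CUDA graph segments). The sketch below only illustrates that split; ToyMoEMLP and its tensor math are assumptions, not Megatron's actual MoELayer.

from typing import Optional

import torch


class ToyMoEMLP:
    """Illustrative stand-in for an MoE MLP with a shared expert (not Megatron's MoELayer)."""

    def combine(self, expert_output: torch.Tensor) -> torch.Tensor:
        # Gather / un-permute routed expert outputs back into token order.
        # Kept free of the shared-expert add so this step can be treated on its own.
        return expert_output  # placeholder for the real combine logic

    def postprocess(
        self, output: torch.Tensor, shared_expert_output: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Fold the shared-expert branch in only after the combine step.
        if shared_expert_output is not None:
            output = output + shared_expert_output
        return output


mlp = ToyMoEMLP()
tokens = torch.randn(4, 8)
out = mlp.postprocess(mlp.combine(tokens), shared_expert_output=torch.zeros_like(tokens))
assert torch.equal(out, tokens)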

megatron/core/models/mamba/mamba_layer_specs.py

Lines changed: 6 additions & 3 deletions
@@ -20,7 +20,11 @@
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
+from megatron.core.transformer.transformer_layer import (
+    MoETransformerLayer,
+    TransformerLayer,
+    TransformerLayerSubmodules,
+)

 moe = get_moe_module_spec(
     use_te=True,
@@ -78,8 +82,7 @@
             ),
         ),
         moe_layer=ModuleSpec(
-            # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs?
-            module=TransformerLayer,
+            module=MoETransformerLayer,
             submodules=TransformerLayerSubmodules(
                 pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add
             ),
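
Note: swapping module=TransformerLayer for module=MoETransformerLayer in the moe_layer spec resolves the old TODO: the class named in the ModuleSpec is what the CUDA graph machinery sees when deciding how to treat the layer. The sketch below shows the general spec-driven construction idea with toy classes; ToyModuleSpec, toy_build_module, and MoELayerVariant are illustrative stand-ins, not the real spec_utils API.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyModuleSpec:
    """Simplified stand-in for megatron.core.transformer.spec_utils.ModuleSpec."""
    module: type
    submodules: Any = None
    params: dict = field(default_factory=dict)


def toy_build_module(spec: ToyModuleSpec, *args, **kwargs):
    # The class named in the spec decides the layer's behaviour; swapping the
    # layer class in the spec changes what downstream hooks see without touching
    # the surrounding model code.
    return spec.module(*args, **{**spec.params, **kwargs})


class PlainLayer:
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size


class MoELayerVariant(PlainLayer):
    """Pretend MoE-aware layer class (stand-in for MoETransformerLayer)."""


spec = ToyModuleSpec(module=MoELayerVariant)
layer = toy_build_module(spec, hidden_size=128)
assert isinstance(layer, MoELayerVariant)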

megatron/core/ssm/mamba_layer.py

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,7 @@
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import GraphableMegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -85,6 +86,13 @@ def __init__(
         self.mamba_bda = build_module(submodules.mamba_bda)
         self.bias_dropout_add_exec_handler = torch.enable_grad

+    def create_mcore_cudagraph_manager(self, config):
+        """Register the mamba layer for cudagraphs."""
+        from megatron.core.transformer.cuda_graphs import CudaGraphManager
+
+        if not self.config.cuda_graph_scope or CudaGraphScope.mamba in self.config.cuda_graph_scope:
+            self.cudagraph_manager = CudaGraphManager(config)
+
     def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]:
         """Returns the Mamba conv and ssm states shapes per request."""
         return self.mixer.mamba_state_shapes_per_request()
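
Note: create_mcore_cudagraph_manager only attaches a CudaGraphManager when cuda_graph_scope is empty (graph everything) or explicitly includes CudaGraphScope.mamba, so Mamba layers can be excluded from graphing without affecting other layer types. Below is a standalone sketch of that gating pattern; the Toy* classes are illustrative stand-ins for the Megatron ones.

from enum import Enum
from typing import Optional, Sequence


class ToyGraphScope(Enum):
    """Illustrative stand-in for megatron.core.transformer.enums.CudaGraphScope."""
    attn = "attn"
    mamba = "mamba"
    moe = "moe"


class ToyCudaGraphManager:
    def __init__(self, config):
        self.config = config  # the real manager would own capture/replay state


class ToyMambaLayer:
    def __init__(self, config: dict, cuda_graph_scope: Optional[Sequence[ToyGraphScope]]):
        self.cudagraph_manager = None
        # Same gating idea as the diff: an empty/None scope means "graph everything";
        # otherwise only create a manager when this layer type is named in the scope.
        if not cuda_graph_scope or ToyGraphScope.mamba in cuda_graph_scope:
            self.cudagraph_manager = ToyCudaGraphManager(config)


graphed = ToyMambaLayer(config={}, cuda_graph_scope=None)
skipped = ToyMambaLayer(config={}, cuda_graph_scope=[ToyGraphScope.attn])
assert graphed.cudagraph_manager is not None and skipped.cudagraph_manager is None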

megatron/core/tensor_parallel/random.py

Lines changed: 48 additions & 3 deletions
@@ -472,6 +472,27 @@ def _fork_rng():
         _set_all_rng_states(*current_states)


+# Global flag that's toggled whenever inside a checkpointing context
+IS_CHECKPOINTING = False
+
+
+def _set_checkpointing():
+    """Set state to checkpointing enabled."""
+    global IS_CHECKPOINTING
+    IS_CHECKPOINTING = True
+
+
+def _unset_checkpointing():
+    """Unset state to checkpointing enabled."""
+    global IS_CHECKPOINTING
+    IS_CHECKPOINTING = False
+
+
+def is_checkpointing():
+    """Check if currently in a checkpoint context."""
+    return IS_CHECKPOINTING
+
+
 class CheckpointFunction(torch.autograd.Function):
     """Checkpoint Function

@@ -484,6 +505,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, distribute_saved_activations, *args):
         """Forward pass."""
+        _set_checkpointing()
+
         ctx.run_function = run_function
         ctx.distribute_saved_activations = distribute_saved_activations

@@ -504,6 +527,7 @@ def forward(ctx, run_function, distribute_saved_activations, *args):
         # Store everything.
         ctx.save_for_backward(*args)

+        _unset_checkpointing()
         return outputs

     # pylint: disable=missing-function-docstring
@@ -515,6 +539,8 @@ def backward(ctx, *args):
                 "Checkpointing is not compatible with .grad(), "
                 "please use .backward() if possible"
             )
+        _set_checkpointing()
+
         inputs = ctx.saved_tensors
         if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
@@ -539,6 +565,8 @@ def backward(ctx, *args):
             )
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+
+        _unset_checkpointing()
         return (None, None) + grads

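
Note: the new IS_CHECKPOINTING flag plus _set_checkpointing / _unset_checkpointing / is_checkpointing gives other subsystems (such as the CUDA graph code) a cheap way to ask whether execution is currently inside CheckpointFunction.forward or its recompute in backward. Below is a minimal standalone version of the same pattern around a toy checkpoint function; it is an illustration of the flag pattern, not Megatron's CheckpointFunction.

import torch

# Module-level flag mirroring the IS_CHECKPOINTING pattern above (toy code, not Megatron's).
_IN_CHECKPOINT = False


def _enter_checkpoint():
    global _IN_CHECKPOINT
    _IN_CHECKPOINT = True


def _exit_checkpoint():
    global _IN_CHECKPOINT
    _IN_CHECKPOINT = False


def in_checkpoint() -> bool:
    """Query other subsystems can use to detect checkpointed execution."""
    return _IN_CHECKPOINT


class ToyCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, x):
        _enter_checkpoint()
        ctx.run_function = run_function
        ctx.save_for_backward(x)  # keep only the input; activations are recomputed later
        out = run_function(x)
        _exit_checkpoint()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        _enter_checkpoint()
        (x,) = ctx.saved_tensors
        detached = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.run_function(detached)  # recompute instead of storing activations
        torch.autograd.backward(out, grad_out)
        _exit_checkpoint()
        return None, detached.grad


y = ToyCheckpoint.apply(torch.sin, torch.randn(3, requires_grad=True))
y.sum().backward()
assert not in_checkpoint()  # the flag is only raised while inside the checkpointed passes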
@@ -615,6 +643,14 @@ def __init__(self, fp8=False):

     def checkpoint(self, run_function, *args):
         """Checkpoint function."""
+
+        # If in cuda graph warmup, disable checkpointing, as 'discard_output_and_register_recompute'
+        # may be called in a separate graph warmup.
+        from megatron.core.transformer.cuda_graphs import is_graph_warmup
+
+        if is_graph_warmup():
+            return run_function(*args)
+
         self.run_function = run_function

         self.rng_states = _get_all_rng_states()
@@ -628,11 +664,14 @@ def checkpoint(self, run_function, *args):
     def _recompute(self, _):
         """Used as a hook to recompute the output."""

-        if self.ctx is None:
-            # The recomputation has been triggered already. Just return.
+        from megatron.core.transformer.cuda_graphs import is_graph_capturing, is_graph_warmup
+
+        # The recomputation has been triggered already. Just return.
+        # Handle cudagraphs, do nothing if currently in graph warmup
+        if self.ctx is None or is_graph_warmup():
             return

-        if not torch.autograd._is_checkpoint_valid():
+        if not torch.autograd._is_checkpoint_valid() and not is_graph_capturing():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad(), "
                 "please use .backward() if possible"
@@ -691,6 +730,12 @@ def discard_output_and_register_recompute(self, hook_tensor):
         in the forward pass and the gradient of the hook_tensor is computed before the recomputed
         tensors are used.
         """
+
+        from megatron.core.transformer.cuda_graphs import is_graph_warmup
+
+        if is_graph_warmup():
+            return
+
         # use resize to release the output tensor memory and still keep the metadata in the tensors.
         # the metadata is still needed for backward
         for output in self.outputs:
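
Note: the second half of the change guards the output-discarding checkpoint wrapper against CUDA graph warmup: during warmup the function just runs eagerly, since discard_output_and_register_recompute and _recompute may fire in a separate warmup pass and the usual forward/recompute pairing cannot be relied on. The sketch below mimics that guard with a toy stand-in for is_graph_warmup (the real helper lives in megatron.core.transformer.cuda_graphs).

from contextlib import contextmanager

# Toy stand-in for megatron.core.transformer.cuda_graphs.is_graph_warmup.
_GRAPH_WARMUP = False


@contextmanager
def graph_warmup():
    """Mark a region as CUDA-graph warmup (the dry runs done before actual capture)."""
    global _GRAPH_WARMUP
    _GRAPH_WARMUP = True
    try:
        yield
    finally:
        _GRAPH_WARMUP = False


def is_graph_warmup() -> bool:
    return _GRAPH_WARMUP


class ToySelectiveCheckpoint:
    """Mimics the guard in `checkpoint`: skip output-discarding checkpointing during warmup."""

    def checkpoint(self, run_function, *args):
        if is_graph_warmup():
            # Warmup passes may not be paired with the matching recompute pass,
            # so just run the function eagerly and keep its outputs alive.
            return run_function(*args)
        # Normal path would stash run_function and RNG state and register recompute
        # hooks; that bookkeeping is elided here, so we simply run the function.
        return run_function(*args)


ckpt = ToySelectiveCheckpoint()
with graph_warmup():
    assert ckpt.checkpoint(lambda x: x * 2, 21) == 42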
