@@ -472,6 +472,27 @@ def _fork_rng():
         _set_all_rng_states(*current_states)
 
 
+# Global flag that's toggled whenever inside a checkpointing context
+IS_CHECKPOINTING = False
+
+
+def _set_checkpointing():
+    """Set state to checkpointing enabled."""
+    global IS_CHECKPOINTING
+    IS_CHECKPOINTING = True
+
+
+def _unset_checkpointing():
+    """Set state to checkpointing disabled."""
+    global IS_CHECKPOINTING
+    IS_CHECKPOINTING = False
+
+
+def is_checkpointing():
+    """Check if currently in a checkpoint context."""
+    return IS_CHECKPOINTING
+
+
 class CheckpointFunction(torch.autograd.Function):
     """Checkpoint Function
 
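
For context, a minimal sketch of how downstream code might consume the new is_checkpointing() flag. The import path and the CachingGelu module below are assumptions for illustration only; note that the flag reads True both during the original checkpointed forward and during the recompute that backward triggers.

import torch

# Assumed import path: the tensor-parallel checkpoint/RNG module this diff patches.
from megatron.core.tensor_parallel.random import is_checkpointing


class CachingGelu(torch.nn.Module):
    """Hypothetical module that skips caching while checkpointing is active."""

    def forward(self, x):
        y = torch.nn.functional.gelu(x)
        # Stashing intermediates on the module is unsafe inside a checkpointed
        # region: the recompute pass in backward would overwrite them mid-step.
        if not is_checkpointing():
            self.last_activation = y.detach()
        return y
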
@@ -484,6 +505,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, distribute_saved_activations, *args):
         """Forward pass."""
+        _set_checkpointing()
+
         ctx.run_function = run_function
         ctx.distribute_saved_activations = distribute_saved_activations
 
@@ -504,6 +527,7 @@ def forward(ctx, run_function, distribute_saved_activations, *args):
         # Store everything.
         ctx.save_for_backward(*args)
 
+        _unset_checkpointing()
         return outputs
 
     # pylint: disable=missing-function-docstring
@@ -515,6 +539,8 @@ def backward(ctx, *args):
                 "Checkpointing is not compatible with .grad(), "
                 "please use .backward() if possible"
             )
+        _set_checkpointing()
+
         inputs = ctx.saved_tensors
         if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
@@ -539,6 +565,8 @@ def backward(ctx, *args):
         )
         torch.autograd.backward(outputs, args)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+
+        _unset_checkpointing()
         return (None, None) + grads
 
 
@@ -615,6 +643,14 @@ def __init__(self, fp8=False):
 
     def checkpoint(self, run_function, *args):
         """Checkpoint function."""
+
+        # If in cuda graph warmup, disable checkpointing, as 'discard_output_and_register_recompute'
+        # may be called in a separate graph warmup.
+        from megatron.core.transformer.cuda_graphs import is_graph_warmup
+
+        if is_graph_warmup():
+            return run_function(*args)
+
         self.run_function = run_function
 
         self.rng_states = _get_all_rng_states()
@@ -628,11 +664,14 @@ def checkpoint(self, run_function, *args):
     def _recompute(self, _):
         """Used as a hook to recompute the output."""
 
-        if self.ctx is None:
-            # The recomputation has been triggered already. Just return.
+        from megatron.core.transformer.cuda_graphs import is_graph_capturing, is_graph_warmup
+
+        # The recomputation has been triggered already. Just return.
+        # Handle cudagraphs: do nothing if currently in graph warmup.
+        if self.ctx is None or is_graph_warmup():
             return
 
-        if not torch.autograd._is_checkpoint_valid():
+        if not torch.autograd._is_checkpoint_valid() and not is_graph_capturing():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad(), "
                 "please use .backward() if possible"
@@ -691,6 +730,12 @@ def discard_output_and_register_recompute(self, hook_tensor):
         in the forward pass and the gradient of the hook_tensor is computed before the recomputed
         tensors are used.
         """
+
+        from megatron.core.transformer.cuda_graphs import is_graph_warmup
+
+        if is_graph_warmup():
+            return
+
         # use resize to release the output tensor memory and still keep the metadata in the tensors.
         # the metadata is still needed for backward
         for output in self.outputs:
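
As a usage sketch of the flag toggling added to CheckpointFunction: assuming Megatron's CUDA RNG tracker has been initialized (e.g. via model_parallel_cuda_manual_seed) and that the module-level checkpoint() wrapper keeps its (function, distribute_saved_activations, *args) signature, the flag should read True once for the checkpointed forward and once more for the recompute inside backward.

import torch

from megatron.core import tensor_parallel
# Assumed import path for the helper added by this diff.
from megatron.core.tensor_parallel.random import is_checkpointing

observed = []


def run_block(x):
    # Record whether we are currently inside a checkpointed forward/recompute.
    observed.append(is_checkpointing())
    return x * x


x = torch.randn(4, device="cuda", requires_grad=True)
y = tensor_parallel.checkpoint(run_block, False, x)
y.sum().backward()
print(observed)  # expected: [True, True] - forward pass, then recompute during backward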