Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,23 @@ def dummy_forward(self):
self.inference_context.is_dynamic_batching()
and self.inference_context.num_speculative_tokens > 0
)
return self.model(tokens, position_ids, attention_mask, is_spec_decode=is_spec_decode)

# If the controller is running the graphed dummy_forward, it will run
# `_dynamic_step_forward_logits` instead. Setting this to False ensures that we
# will not try to match on a cudagraph when this is running eager.
self.inference_context._using_cuda_graph_this_step = False

# Pass inference_context so that transformer & mamba layers use the inference-mode
# cudagraph check (which gates on using_cuda_graph_this_step()) instead of
# the training-mode check (which unconditionally replays captured graphs
# when inference_context is not in kwargs).
return self.model(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit scared that this will lead to segfaults. If we pass the inference context, we will attempt to write into the kv-cache/mamba-cache. But we are not setting up the required pointers properly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this in favour of gating on inference_mode instead

tokens,
position_ids,
attention_mask,
inference_context=self.inference_context,
is_spec_decode=is_spec_decode,
)

def _get_batch_size_and_seq_len(
self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ def _dynamic_step_calculate_top_n_logprobs(

def dummy_forward(self):
"""Perform a dummy forward pass. This is used in expert model parallelism
on ranks that do not have any real requests. It is meant to run in eager mode."""
on ranks that do not have any real requests. It may run in eager mode."""

context = self.inference_wrapped_model.inference_context
# if no cuda graphs, directly use dummy forward
Expand Down
6 changes: 1 addition & 5 deletions megatron/core/ssm/mamba_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
Check if we should call the local cudagraph path.
"""
# Training and validation mode CUDA graphs
if (
hasattr(self, 'cudagraph_manager')
and kwargs.get('inference_context') is None
and (self.training or _CudagraphGlobalRecord.cudagraph_created)
):
if hasattr(self, 'cudagraph_manager') and kwargs.get('inference_context') is None:
return True
elif not self.training and (
hasattr(self, 'cudagraph_manager')
Expand Down
6 changes: 1 addition & 5 deletions megatron/core/transformer/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,11 +1253,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
Check if we should call the local cudagraph path.
"""
# Training and validation mode CUDA graphs
if (
hasattr(self, 'cudagraph_manager')
and kwargs.get('inference_context') is None
and (self.training or _CudagraphGlobalRecord.cudagraph_created)
):
if hasattr(self, 'cudagraph_manager') and kwargs.get('inference_context') is None:
return True
# Inference mode. CUDA graphs are used in the decode phase only, when attn mask is None
elif not self.training and (
Expand Down
Loading