Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def update(

has_chunked_prefill_req = enable_chunked_prefill and real_prefill_count > 0

# Mamba models cannot have padded prefill requests without real prefill requests.
# If this happens (e.g. on EP dummy ranks with incorrect setup), cu_seqlens will
# be all zeros, causing an illegal memory access in the Mamba SSM prefill kernel.
assert not (padded_prefill_count > 0 and real_prefill_count == 0), (
f"Mamba models require real prefill requests when padded prefill count > 0. "
f"Real prefill: {real_prefill_count}, Padded prefill: {padded_prefill_count}. "
f"This can occur on EP dummy ranks if they are not properly set up to match "
f"the selected CUDA graph's prefill/decode split."
)

# Although the context ensures that the last request is always the designated
# chunked prefill request, what we actually care about is ensuring that any
# prefill request with non-zero initial states is executed through the
Expand Down
96 changes: 75 additions & 21 deletions megatron/core/inference/contexts/dynamic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,9 @@ def num_decode_requests(self) -> int:
"""
return self.total_request_count - self.paused_request_count - self.num_prefill_requests

def add_dummy_requests_for_expert_parallel_step(self) -> None:
def add_dummy_requests_for_expert_parallel_step(
self, target_graph: Optional[InferenceBatchDimensions] = None
) -> None:
"""Minimal context setup so an EP rank with no real requests can replay
an already-captured cuda graph without crashing or corrupting memory.

Expand All @@ -1291,37 +1293,73 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None:
We set up minimal state such that initialize_attention_state and the forward
pass can run without error.

Args:
target_graph: The matched CUDA graph dimensions to set up dummy requests for.
If None, uses the smallest decode-only graph. When provided, the dummy
setup will match the graph's prefill/decode split, which is critical for
Mamba hybrid models to avoid all-zero cu_seqlens on EP dummy ranks.
"""
smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list)
# the smallest cuda graph is decode only.
assert smallest_cuda_graph_dimensions.prefill_req_count == 0

N = smallest_cuda_graph_dimensions.decode_req_count
if target_graph is not None:
target_dims = target_graph
else:
target_dims = min(self.cuda_graph_batch_dimensions_list)
# the smallest cuda graph is decode only.
assert target_dims.prefill_req_count == 0

N_decode = target_dims.decode_req_count
N_prefill = target_dims.prefill_req_count
N_total = N_decode + N_prefill
T = target_dims.token_count
dummy_block_idx = self.block_allocator.dummy_block_idx

# 1. Request counts and token count (decode-only: 1 token per request).
self.total_request_count = N
self.active_token_count = N
self.num_prefill_requests = 0
# 1. Request counts and token count.
self.total_request_count = N_total
self.active_token_count = T
self.num_prefill_requests = N_prefill

# 2. Per-request state consumed by mha_metadata.update().
self.request_query_lengths[0:N].fill_(1)
self.request_kv_length_offsets[0:N].fill_(0)
self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx
# Decode requests: 1 token each.
self.request_query_lengths[0:N_decode].fill_(1)
# Prefill requests: distribute remaining tokens evenly.
if N_prefill > 0:
prefill_tokens = T - N_decode
tokens_per_prefill = prefill_tokens // N_prefill
remainder = prefill_tokens % N_prefill
self.request_query_lengths[N_decode:N_total].fill_(tokens_per_prefill)
# Give remainder tokens to the last prefill request.
if remainder > 0:
self.request_query_lengths[N_total - 1] += remainder

self.request_kv_length_offsets[0:N_total].fill_(0)
self.request_to_kv_block_ids[0:N_total, 0] = dummy_block_idx

# 3. Token-level state consumed by the triton KV append kernel.
self.token_to_block_idx[0:N] = dummy_block_idx
self.token_to_local_position_within_kv_block[0:N] = 0
self.token_to_block_idx[0:T] = dummy_block_idx
self.token_to_local_position_within_kv_block[0:T] = 0

if self.is_hybrid_model:
# 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models.
self.token_to_request_idx[0:N] = torch.arange(
0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype
# Decode tokens: 1:1 mapping to decode requests.
self.token_to_request_idx[0:N_decode] = torch.arange(
0,
N_decode,
device=self.token_to_request_idx.device,
dtype=self.token_to_request_idx.dtype,
)
# Prefill tokens: distribute among prefill requests.
if N_prefill > 0:
prefill_tokens = T - N_decode
tokens_per_prefill = prefill_tokens // N_prefill
remainder = prefill_tokens % N_prefill
token_offset = N_decode
for i in range(N_prefill):
qlen = tokens_per_prefill + (remainder if i == N_prefill - 1 else 0)
self.token_to_request_idx[token_offset : token_offset + qlen] = N_decode + i
token_offset += qlen

# 5. Mamba state: allocate slots for dummy requests.
self.mamba_metadata.request_to_mamba_state_idx[0:N] = (
self.mamba_metadata.batch_allocate_slots(N)
self.mamba_metadata.request_to_mamba_state_idx[0:N_total] = (
self.mamba_metadata.batch_allocate_slots(N_total)
)

def initialize_attention_state(
Expand All @@ -1343,9 +1381,14 @@ def initialize_attention_state(
self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step
), "Dummy expert model parallel steps should not be creating cuda graphs."

# If in CUDA graph creation mode, add dummy requests for CUDA graph capture
# If in CUDA graph creation mode, add dummy requests for CUDA graph capture.
# For EP dummy steps, we delay the dummy setup until after graph matching
# so we can match the selected graph's prefill/decode split. This is critical
# for Mamba hybrid models: if the matched graph has prefill requests but the
# dummy rank has none, cu_seqlens will be all-zeros, causing illegal memory
# access in the Mamba SSM prefill kernel.
if is_expert_parallel_dummy_cuda_graph_step:
self.add_dummy_requests_for_expert_parallel_step()
pass # Dummy setup deferred to after match_graph_config (below).
elif self.is_creating_cuda_graphs:
self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions)

Expand Down Expand Up @@ -1378,6 +1421,17 @@ def initialize_attention_state(
# will directly call the model forward pass with a single token.
return

if is_expert_parallel_dummy_cuda_graph_step:
# Deferred dummy setup: now that we know the matched graph, set up
# dummy requests to match its prefill/decode split.
self.add_dummy_requests_for_expert_parallel_step(best_graph)
batch_dimensions = InferenceBatchDimensions(
token_count=self.active_token_count,
prefill_req_count=self.num_prefill_requests,
decode_req_count=self.num_decode_requests,
)
self.batch_dimensions = batch_dimensions

if self.using_cuda_graph_this_step():
self.padded_batch_dimensions = best_graph
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,33 @@ def test_update_padded_prefill_and_decode(self, metadata_context):
expected_seq_idx[:, :10] = 0
assert torch.equal(metadata_context.seq_idx, expected_seq_idx)

# -------------------------------------------------------------------------
# Scenario 3b: Padded prefill without real prefill (EP dummy rank bug)
# -------------------------------------------------------------------------

@pytest.mark.internal
def test_update_rejects_padded_prefill_without_real_prefill(self, metadata_context):
    """A padded prefill count with zero real prefill requests must be rejected.

    EP dummy ranks can reach this state when the matched CUDA graph reserves
    prefill slots but the rank itself carries no real prefill requests. The
    Mamba SSM kernel would then see all-zero cu_seqlens and hit an illegal
    memory access, so update() is expected to assert instead.
    """
    # Two decode-only requests; no real prefill anywhere in the batch.
    decode_only_lengths = [1, 1]
    decode_count = 2
    # Graph dimensions that nevertheless advertise two prefill slots.
    graph_dims = InferenceBatchDimensions(
        token_count=32, prefill_req_count=2, decode_req_count=2
    )

    with pytest.raises(AssertionError, match="Mamba models require real prefill requests"):
        self._run_update_test(
            metadata_context,
            decode_only_lengths,
            decode_count,
            graph_dims,
            enable_chunked_prefill=False,
        )

# -------------------------------------------------------------------------
# Scenario 4: Chunked Prefill
# -------------------------------------------------------------------------
Expand Down
78 changes: 78 additions & 0 deletions tests/unit_tests/inference/contexts/test_dynamic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,84 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path(
assert (slow_mamba >= 0).all(), "slow path should allocate valid mamba slots"
assert fast_mamba.unique().numel() == N, "fast path mamba slots must be unique"

@pytest.mark.internal
@pytest.mark.parametrize("is_hybrid_model", [True])
@pytest.mark.parametrize("num_cuda_graphs", [4, -1])
def test_add_dummy_requests_for_expert_parallel_step_with_prefill_graph(
    self, is_hybrid_model: bool, num_cuda_graphs: int
):
    """Dummy EP setup must mirror a prefill-carrying target graph exactly.

    When the target graph reserves prefill slots, the dummy requests created
    on an EP rank have to reproduce that prefill/decode split with a valid
    token distribution; otherwise Mamba hybrid models would see all-zero
    cu_seqlens on the dummy rank.
    """
    self._setup_model_parallel_group(1, 1)

    ctx = self._get_dynamic_context(
        params_dtype=torch.float32,
        num_layers=4,
        kv_channels=8,
        num_attention_heads=2,
        max_sequence_length=512,
        buffer_size_gb=0.03,
        block_size_tokens=128,
        max_tokens=None,
        is_hybrid_model=is_hybrid_model,
        layer_type_list=[Symbols.MAMBA, Symbols.ATTENTION, Symbols.MLP, Symbols.ATTENTION],
        num_cuda_graphs=num_cuda_graphs,
    )

    # Pick a graph that carries both prefill and decode requests.
    mixed_graphs = [
        dims for dims in ctx.cuda_graph_batch_dimensions_list if dims.prefill_req_count > 0
    ]
    if not mixed_graphs:
        pytest.skip("No mixed CUDA graphs available for this configuration")

    # NOTE(review): assumes the dimensions list is sorted largest-first, so
    # the last mixed entry is the smallest mixed graph — confirm ordering.
    target_graph = mixed_graphs[-1]
    num_decode = target_graph.decode_req_count
    num_prefill = target_graph.prefill_req_count
    num_total = num_decode + num_prefill
    num_tokens = target_graph.token_count

    ctx.add_dummy_requests_for_expert_parallel_step(target_graph)

    # 1. Scalar counts mirror the target graph.
    assert ctx.total_request_count == num_total
    assert ctx.active_token_count == num_tokens
    assert ctx.num_prefill_requests == num_prefill
    assert ctx.num_decode_requests == num_decode

    # 2. Query lengths: each decode request gets exactly one token, and the
    #    prefill requests together account for every remaining token.
    assert torch.all(ctx.request_query_lengths[:num_decode] == 1)
    prefill_qlens = ctx.request_query_lengths[num_decode:num_total]
    assert torch.all(prefill_qlens >= 1), "Each prefill request must have at least 1 token"
    assert prefill_qlens.sum().item() == num_tokens - num_decode

    # 3. Every active token points at the dummy KV block.
    assert torch.all(
        ctx.token_to_block_idx[:num_tokens] == ctx.block_allocator.dummy_block_idx
    )

    # 4. Hybrid model: token -> request mapping is consistent.
    token_to_req = ctx.token_to_request_idx[:num_tokens]
    # The first num_decode tokens map one-to-one onto the decode requests.
    assert [int(v) for v in token_to_req[:num_decode]] == list(range(num_decode))
    # The remaining tokens fall inside the prefill request-id range...
    prefill_req_ids = token_to_req[num_decode:num_tokens]
    assert (prefill_req_ids >= num_decode).all()
    assert (prefill_req_ids < num_total).all()
    # ...and each prefill request owns exactly as many tokens as its query length.
    for req_id in range(num_decode, num_total):
        owned = int((prefill_req_ids == req_id).sum())
        assert owned == int(ctx.request_query_lengths[req_id])

    # 5. Mamba state slots: one valid, distinct slot per dummy request.
    mamba_slots = ctx.mamba_metadata.request_to_mamba_state_idx[:num_total]
    assert (mamba_slots >= 0).all(), "All mamba slots must be valid"
    assert mamba_slots.unique().numel() == num_total, "All mamba slots must be unique"

@pytest.mark.internal
def test_gqa_high_tp_partition_heads(self):
"""Tests that TP > GQA results in 1 attention head per partition."""
Expand Down