diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py
index 49fa40f1d6c..7e0d8ab2511 100644
--- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py
+++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py
@@ -127,6 +127,16 @@ def update(
 
         has_chunked_prefill_req = enable_chunked_prefill and real_prefill_count > 0
 
+        # Mamba models cannot have padded prefill requests without real prefill requests.
+        # If this happens (e.g. on EP dummy ranks with incorrect setup), cu_seqlens will
+        # be all zeros, causing an illegal memory access in the Mamba SSM prefill kernel.
+        assert not (padded_prefill_count > 0 and real_prefill_count == 0), (
+            f"Mamba models require real prefill requests when padded prefill count > 0. "
+            f"Real prefill: {real_prefill_count}, Padded prefill: {padded_prefill_count}. "
+            f"This can occur on EP dummy ranks if they are not properly set up to match "
+            f"the selected CUDA graph's prefill/decode split."
+        )
+
         # Although the context ensures that the last request is always the designated
         # chunked prefill request, what we actually care about is ensuring that any
         # prefill request with non-zero initial states is executed through the
diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index cd55882db6b..e2539ea8e6e 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -1346,7 +1346,9 @@ def num_decode_requests(self) -> int:
         """
         return self.total_request_count - self.paused_request_count - self.num_prefill_requests
 
-    def add_dummy_requests_for_expert_parallel_step(self) -> None:
+    def add_dummy_requests_for_expert_parallel_step(
+        self, target_graph: Optional[InferenceBatchDimensions] = None
+    ) -> None:
         """Minimal context setup so an EP rank with no real requests can replay
         an already-captured cuda graph without crashing or corrupting memory.
 
@@ -1356,50 +1358,122 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None:
 
         We setup minimal state such the initialize_attention_state and the forward
         pass can run without error.
+        Args:
+            target_graph: The matched CUDA graph dimensions to set up dummy requests for.
+                If None, uses the smallest decode-only graph. When provided, the dummy
+                setup will match the graph's prefill/decode split, which is critical for
+                Mamba hybrid models to avoid all-zero cu_seqlens on EP dummy ranks.
         """
-        smallest_cuda_graph_dimensions = min(
-            [x for x in self.cuda_graph_batch_dimensions_list if x.prefill_req_count == 0]
-        )
-        # the smallest cuda graph is decode only.
-        assert smallest_cuda_graph_dimensions.prefill_req_count == 0
-
-        N = smallest_cuda_graph_dimensions.decode_req_count
-        tokens_per_request = self.num_speculative_tokens + 1
-        T = smallest_cuda_graph_dimensions.token_count  # N * tokens_per_request
+        if target_graph is not None:
+            target_dims = target_graph
+        else:
+            target_dims = min(
+                [x for x in self.cuda_graph_batch_dimensions_list if x.prefill_req_count == 0]
+            )
+            # the smallest cuda graph is decode only.
+            assert target_dims.prefill_req_count == 0
+
+        N_decode = target_dims.decode_req_count
+        N_prefill = target_dims.prefill_req_count
+        N_total = N_decode + N_prefill
+        T = target_dims.token_count
+        tokens_per_decode_request = self.num_speculative_tokens + 1
+
+        decode_tokens_total = N_decode * tokens_per_decode_request
+        prefill_tokens_total = T - decode_tokens_total
+        # Split the non-decode tokens across prefill requests once; the same split
+        # drives both request_query_lengths and token_to_request_idx below.
+        tokens_per_prefill, remainder = (
+            divmod(prefill_tokens_total, N_prefill) if N_prefill > 0 else (0, 0)
+        )
+        # Each dummy prefill request must receive at least one token.
+        assert N_prefill == 0 or tokens_per_prefill >= 1, (
+            f"Graph {target_dims} leaves {prefill_tokens_total} tokens for "
+            f"{N_prefill} dummy prefill requests; each needs at least one."
+        )
         dummy_block_idx = self.block_allocator.dummy_block_idx
 
+        # Token distribution scheme:
+        # - Decode requests each get exactly `tokens_per_decode_request` tokens
+        #   (i.e. 1 + num_speculative_tokens).
+        # - The remaining tokens (T - N_decode * tokens_per_decode_request) are
+        #   divided evenly across prefill requests, with any remainder added to
+        #   the last prefill request.
+        # This mirrors the distribution in add_dummy_requests_for_cudagraph_capture
+        # (which distributes remainder starting from the first request instead) but
+        # the exact per-request sizes don't matter here — only the totals need to
+        # match the CUDA graph dimensions so the captured graph replays correctly.
+
         # 1. Request counts and token count.
         # With speculative decoding each decode request has (num_speculative_tokens + 1) tokens.
-        self.total_request_count = N
+        self.total_request_count = N_total
         self.active_token_count = T
-        self.num_prefill_requests = 0
+        self.num_prefill_requests = N_prefill
 
         # 2. Per-request state consumed by mha_metadata.update().
-        self.request_query_lengths[0:N].fill_(tokens_per_request)
-        self.request_kv_length_offsets[0:N].fill_(0)
-        self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx
+        if N_decode > 0:
+            self.request_query_lengths[0:N_decode] = tokens_per_decode_request
+
+        # Prefill requests: distribute remaining tokens evenly.
+        if N_prefill > 0:
+            self.request_query_lengths[N_decode:N_total] = tokens_per_prefill
+            # Give remainder tokens to the last prefill request.
+            if remainder > 0:
+                self.request_query_lengths[N_total - 1] += remainder
+
+        self.request_kv_length_offsets[0:N_total] = 0
+        self.request_to_kv_block_ids[0:N_total, 0] = dummy_block_idx
 
         # 3. Token-level state consumed by the triton KV append kernel.
         self.token_to_block_idx[0:T] = dummy_block_idx
-        self.token_to_local_position_within_kv_block[0:T] = (
-            torch.arange(T, device=self.token_to_block_idx.device) % tokens_per_request
-        )
+
+        if decode_tokens_total > 0:
+            self.token_to_local_position_within_kv_block[0:decode_tokens_total] = (
+                torch.arange(decode_tokens_total, device=self.token_to_block_idx.device)
+                % tokens_per_decode_request
+            )
+
+        if N_prefill > 0:
+            self.token_to_local_position_within_kv_block[decode_tokens_total:T] = (
+                torch.arange(prefill_tokens_total, device=self.token_to_block_idx.device)
+                % self.block_size_tokens
+            )
 
         if self.is_hybrid_model:
             # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models.
-            self.token_to_request_idx[0:T] = torch.repeat_interleave(
-                torch.arange(
-                    0,
-                    N,
-                    device=self.token_to_request_idx.device,
+            if N_decode > 0:
+                self.token_to_request_idx[0:decode_tokens_total] = torch.repeat_interleave(
+                    torch.arange(
+                        0,
+                        N_decode,
+                        device=self.token_to_request_idx.device,
+                        dtype=self.token_to_request_idx.dtype,
+                    ),
+                    tokens_per_decode_request,
+                )
+
+            # Prefill tokens: distribute among prefill requests.
+            if N_prefill > 0:
+                # Build repeat counts on CPU, then do a single write to GPU.
+                repeat_counts = torch.full(
+                    (N_prefill,), tokens_per_prefill, dtype=torch.long, device="cpu"
+                )
+                if remainder > 0:
+                    repeat_counts[-1] += remainder
+                request_ids = torch.arange(
+                    N_decode,
+                    N_total,
                     dtype=self.token_to_request_idx.dtype,
-                ),
-                tokens_per_request,
-            )
+                    device="cpu",
+                )
+                token_request_indices = torch.repeat_interleave(request_ids, repeat_counts)
+                self.token_to_request_idx[decode_tokens_total:T] = token_request_indices.to(
+                    device=self.token_to_request_idx.device
+                )
 
             # 5. Mamba state: allocate slots for dummy requests.
-            self.mamba_metadata.request_to_mamba_state_idx[0:N] = (
-                self.mamba_metadata.batch_allocate_slots(N)
+            self.mamba_metadata.request_to_mamba_state_idx[0:N_total] = (
+                self.mamba_metadata.batch_allocate_slots(N_total)
             )
 
     def initialize_attention_state(
@@ -1423,9 +1497,14 @@ def initialize_attention_state(
             self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step
         ), "Dummy expert model parallel steps should not be creating cuda graphs."
 
-        # If in CUDA graph creation mode, add dummy requests for CUDA graph capture
+        # If in CUDA graph creation mode, add dummy requests for CUDA graph capture.
+        # For EP dummy steps, we delay the dummy setup until after graph matching
+        # so we can match the selected graph's prefill/decode split. This is critical
+        # for Mamba hybrid models: if the matched graph has prefill requests but the
+        # dummy rank has none, cu_seqlens will be all-zeros, causing illegal memory
+        # access in the Mamba SSM prefill kernel.
         if is_expert_parallel_dummy_cuda_graph_step:
-            self.add_dummy_requests_for_expert_parallel_step()
+            pass  # Dummy setup deferred to after match_graph_config (below).
         elif self.is_creating_cuda_graphs:
             self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions)
 
@@ -1458,6 +1537,17 @@ def initialize_attention_state(
             # will directly call the model forward pass with a single token.
             return
 
+        if is_expert_parallel_dummy_cuda_graph_step:
+            # Deferred dummy setup: now that we know the matched graph, set up
+            # dummy requests to match its prefill/decode split.
+            self.add_dummy_requests_for_expert_parallel_step(best_graph)
+            batch_dimensions = InferenceBatchDimensions(
+                token_count=self.active_token_count,
+                prefill_req_count=self.num_prefill_requests,
+                decode_req_count=self.num_decode_requests,
+            )
+            self.batch_dimensions = batch_dimensions
+
         if self.using_cuda_graph_this_step():
             self.padded_batch_dimensions = best_graph
         else:
diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py
index 697beb867e0..44117f37ee7 100644
--- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py
+++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py
@@ -306,6 +306,29 @@ def test_update_padded_prefill_and_decode(self, metadata_context):
         expected_seq_idx[:, :10] = 0
         assert torch.equal(metadata_context.seq_idx, expected_seq_idx)
 
+    # -------------------------------------------------------------------------
+    # Scenario 3b: Padded prefill without real prefill (EP dummy rank bug)
+    # -------------------------------------------------------------------------
+
+    @pytest.mark.internal
+    def test_update_rejects_padded_prefill_without_real_prefill(self, metadata_context):
+        """Padded prefill > 0 with real prefill == 0 must raise an assertion.
+
+        This scenario can happen on EP dummy ranks when the matched CUDA graph has
+        prefill slots but the dummy rank has no real prefill requests. The Mamba SSM
+        kernel would crash with an illegal memory access due to all-zero cu_seqlens.
+        """
+        seq_lengths = [1, 1]  # 2 decode requests, 0 prefill
+        num_decode = 2
+        padded_dims = InferenceBatchDimensions(
+            token_count=32, prefill_req_count=2, decode_req_count=2
+        )
+
+        with pytest.raises(AssertionError, match="Mamba models require real prefill requests"):
+            self._run_update_test(
+                metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False
+            )
+
     # -------------------------------------------------------------------------
     # Scenario 4: Chunked Prefill
     # -------------------------------------------------------------------------
diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py
index c9aa5341c05..a24c7a28999 100644
--- a/tests/unit_tests/inference/contexts/test_dynamic_context.py
+++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py
@@ -1501,6 +1501,82 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path(
             assert (slow_mamba >= 0).all(), "slow path should allocate valid mamba slots"
             assert fast_mamba.unique().numel() == N, "fast path mamba slots must be unique"
 
+    @pytest.mark.internal
+    @pytest.mark.parametrize("is_hybrid_model", [True])
+    @pytest.mark.parametrize("num_cuda_graphs", [4, -1])
+    def test_add_dummy_requests_for_expert_parallel_step_with_prefill_graph(
+        self, is_hybrid_model: bool, num_cuda_graphs: int
+    ):
+        """When target_graph has prefill requests, the dummy setup must create
+        matching prefill requests with valid token distributions. This prevents
+        all-zero cu_seqlens on EP dummy ranks for Mamba hybrid models.
+        """
+        self._setup_model_parallel_group(1, 1)
+
+        ctx = self._get_dynamic_context(
+            params_dtype=torch.float32,
+            num_layers=4,
+            kv_channels=8,
+            num_attention_heads=2,
+            max_sequence_length=512,
+            buffer_size_gb=0.03,
+            block_size_tokens=128,
+            max_tokens=None,
+            is_hybrid_model=is_hybrid_model,
+            layer_type_list=[Symbols.MAMBA, Symbols.ATTENTION, Symbols.MLP, Symbols.ATTENTION],
+            num_cuda_graphs=num_cuda_graphs,
+        )
+
+        # Find a mixed (prefill + decode) graph from the list
+        mixed_graphs = [g for g in ctx.cuda_graph_batch_dimensions_list if g.prefill_req_count > 0]
+        if not mixed_graphs:
+            pytest.skip("No mixed CUDA graphs available for this configuration")
+
+        target_graph = mixed_graphs[-1]  # smallest mixed graph
+        N_decode = target_graph.decode_req_count
+        N_prefill = target_graph.prefill_req_count
+        N_total = N_decode + N_prefill
+        T = target_graph.token_count
+
+        ctx.add_dummy_requests_for_expert_parallel_step(target_graph)
+
+        # 1. Scalar counts match the target graph
+        assert ctx.total_request_count == N_total
+        assert ctx.active_token_count == T
+        assert ctx.num_prefill_requests == N_prefill
+        assert ctx.num_decode_requests == N_decode
+
+        # 2. Query lengths: decode=1, prefill tokens sum to T - N_decode
+        decode_qlens = ctx.request_query_lengths[:N_decode]
+        assert torch.all(decode_qlens == 1)
+        prefill_qlens = ctx.request_query_lengths[N_decode:N_total]
+        assert torch.all(prefill_qlens >= 1), "Each prefill request must have at least 1 token"
+        assert prefill_qlens.sum().item() == T - N_decode
+
+        # 3. Token-level state uses dummy blocks
+        dummy_block_idx = ctx.block_allocator.dummy_block_idx
+        assert torch.all(ctx.token_to_block_idx[:T] == dummy_block_idx)
+
+        # 4. Hybrid model: token_to_request_idx is valid
+        token_to_req = ctx.token_to_request_idx[:T]
+        # Decode tokens should map to requests 0..N_decode-1
+        for i in range(N_decode):
+            assert token_to_req[i].item() == i
+        # Prefill tokens should map to requests N_decode..N_total-1
+        prefill_req_ids = token_to_req[N_decode:T]
+        assert (prefill_req_ids >= N_decode).all()
+        assert (prefill_req_ids < N_total).all()
+        # Check each prefill request gets the right number of tokens
+        for i in range(N_prefill):
+            req_id = N_decode + i
+            count = (prefill_req_ids == req_id).sum().item()
+            assert count == ctx.request_query_lengths[req_id].item()
+
+        # 5. Mamba state slots are valid and unique
+        mamba_slots = ctx.mamba_metadata.request_to_mamba_state_idx[:N_total]
+        assert (mamba_slots >= 0).all(), "All mamba slots must be valid"
+        assert mamba_slots.unique().numel() == N_total, "All mamba slots must be unique"
+
     @pytest.mark.internal
     def test_gqa_high_tp_partition_heads(self):
         """Tests that TP > GQA results in 1 attention head per partition."""