Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 18 additions & 28 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,40 +1045,30 @@ def post_process_requests(
if not request.generated_log_probs:
request.generated_log_probs = []

# For chunked prefill with materialize_only_last_token_logits, discard intermediate log probs
if (
request_id == self.context.chunked_prefill_request_id
and self.materialize_only_last_token_logits
):
request.prompt_log_probs = []
request.generated_log_probs = []
is_chunked_prefill = request_id == self.context.chunked_prefill_request_id
is_prefill = len(request_log_probs) > 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we have scheduling code to prevent prefill requests from having a single token, but would be good to double check that there's no edge case here


if request.sampling_params.skip_prompt_log_probs:
# We only want decode log probs.
if is_chunked_prefill:
pass
elif is_prefill:
request.generated_log_probs.append(request_log_probs[-1])
else:
request.generated_log_probs.extend(request_log_probs)
else:
# Split log probs between prompt and generated based on remaining prompt slots.
prompt_length = len(request.prompt_tokens)
total_accumulated = len(request.prompt_log_probs) + len(
request.generated_log_probs
)
remaining_prompt_slots = max(0, prompt_length - 1 - total_accumulated)
split_idx = min(remaining_prompt_slots, len(request_log_probs))

# Handle skip_prompt_log_probs during prefill
# If skip_prompt_log_probs is True and we have multiple log probs (prefill),
# only process the last one (first generated token)
if request.sampling_params.skip_prompt_log_probs and len(request_log_probs) > 1:
# Only append the last log prob (first generated token) to generated_log_probs
request.generated_log_probs.append(request_log_probs[-1])
else:
# Vectorized approach: calculate split point and use list slicing
if not request.sampling_params.skip_prompt_log_probs:
# Calculate how many log probs go to prompt vs generated
remaining_prompt_slots = max(0, prompt_length - 1 - total_accumulated)
split_idx = min(remaining_prompt_slots, len(request_log_probs))

# Batch extend instead of individual appends
if split_idx > 0:
request.prompt_log_probs.extend(request_log_probs[:split_idx])
if split_idx < len(request_log_probs):
request.generated_log_probs.extend(request_log_probs[split_idx:])
else:
# All log probs go to generated
request.generated_log_probs.extend(request_log_probs)
if split_idx > 0:
request.prompt_log_probs.extend(request_log_probs[:split_idx])
if split_idx < len(request_log_probs):
request.generated_log_probs.extend(request_log_probs[split_idx:])

# Process top_n_logprobs if available (unified for both regular and chunked prefill)
if top_n_logprobs is not None and req_idx in top_n_logprobs:
Expand Down
49 changes: 37 additions & 12 deletions tests/unit_tests/inference/engines/test_dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,26 +1534,45 @@ def test_chunked_prefill_avoid_single_token_chunk(self):
@pytest.mark.skipif(
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
)
@pytest.mark.parametrize("materialize_only_last_token_logits", [True, False])
@pytest.mark.parametrize("skip_prompt_log_probs", [True, False])
@torch.inference_mode()
def test_chunked_prefill_with_log_probs(self):
def test_chunked_prefill_with_log_probs(
self, materialize_only_last_token_logits: bool, skip_prompt_log_probs: bool
):
"""
Test that chunked prefill correctly handles log probs with materialize_only_last_token_logits.
This verifies that intermediate log probs are properly discarded during chunked prefill.
Test that chunked prefill correctly handles log probs across all branches
of the log-prob accumulation logic.
When materialize_only_last_token_logits=True, skip_prompt_log_probs must be True.
"""
if materialize_only_last_token_logits and not skip_prompt_log_probs:
with pytest.raises(AssertionError, match="only last token logits are materialized"):
self._run_test(
num_requests=1,
min_prompt_length=1200,
max_prompt_length=1200,
num_tokens_to_generate=8,
materialize_only_last_token_logits=True,
return_log_probs=True,
skip_prompt_log_probs=False,
model_provider="gpt",
context_block_size_tokens=256,
context_max_tokens=1000,
enable_chunked_prefill=True,
)
return

prompt_length = 1200
num_tokens_to_generate = 8

# Run with chunked prefill, materialize_only_last_token_logits=True, and skip_prompt_log_probs=True
# This is the only valid combination for chunked prefill with last-token-only logits
env = self._run_test(
num_requests=1,
min_prompt_length=prompt_length,
max_prompt_length=prompt_length,
num_tokens_to_generate=num_tokens_to_generate,
materialize_only_last_token_logits=True,
materialize_only_last_token_logits=materialize_only_last_token_logits,
return_log_probs=True,
skip_prompt_log_probs=True,
skip_prompt_log_probs=skip_prompt_log_probs,
model_provider="gpt",
context_block_size_tokens=256,
context_max_tokens=1000,
Expand All @@ -1574,11 +1593,17 @@ def test_chunked_prefill_with_log_probs(self):
f"generated log probs, got {len(request.generated_log_probs)}"
)

# When skip_prompt_log_probs is True, prompt_log_probs should be empty
assert request.prompt_log_probs is None or len(request.prompt_log_probs) == 0, (
f"Request {request.request_id}: prompt_log_probs should be empty when "
f"skip_prompt_log_probs=True, but got {len(request.prompt_log_probs) if request.prompt_log_probs else 0} items"
)
if skip_prompt_log_probs:
assert request.prompt_log_probs is None or len(request.prompt_log_probs) == 0, (
f"Request {request.request_id}: prompt_log_probs should be empty when "
f"skip_prompt_log_probs=True, but got "
f"{len(request.prompt_log_probs) if request.prompt_log_probs else 0} items"
)
else:
assert len(request.prompt_log_probs) == prompt_length - 1, (
f"Request {request.request_id}: Expected {prompt_length - 1} "
f"prompt log probs, got {len(request.prompt_log_probs)}"
)

# Validate each generated log prob
for i, log_prob in enumerate(request.generated_log_probs):
Expand Down
Loading