Conversation

@raulchen
Contributor

Summary

  • Add last_token_logits_only parameter to Llama3 and Qwen3 models to avoid materializing the full [B, T, V] logits tensor during prefill
  • When prompt_logprobs=False (the common case), only compute logits for the last token
  • Add parametrized tests for both models verifying output shape and generation equivalence

Motivation

During prefill, only the last token's logits are needed to start decoding. Computing logits for all prompt tokens requires a
[B, T, V] matmul where V (vocab size) is typically 32K-128K. This is wasteful when prompt_logprobs is not requested.
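
A minimal sketch of the idea, assuming a typical decoder forward pass (the names below are illustrative, not the PR's actual code):

    # Hypothetical helper showing where the savings come from; lm_head is
    # the vocabulary projection (hidden size H -> vocab size V).
    def compute_logits(hidden_states, lm_head, last_token_logits_only=False):
        # hidden_states: [B, T, H]
        if last_token_logits_only:
            # Slice before the projection: the matmul becomes [B, 1, H] @ [H, V]
            # instead of [B, T, H] @ [H, V], so the full [B, T, V] tensor is
            # never materialized.
            hidden_states = hidden_states[:, -1:, :]
        return lm_head(hidden_states)  # [B, 1, V] or [B, T, V]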

This optimization benefits:

  • Standard inference/chat (most common)
  • RL training (only generation logprobs needed, not prompt logprobs)

Test plan

  • test_last_token_logits_only[llama3] - verifies output shape and generation equivalence
  • test_last_token_logits_only[qwen3] - verifies output shape and generation equivalence (see the sketch after this list)
  • Existing generator tests pass
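
A rough sketch of what the parametrized test could look like (pytest; build_model_and_inputs is a hypothetical helper, and the real test in test_models_common.py may differ):

    import pytest
    import torch

    @pytest.mark.parametrize("model_name", ["llama3", "qwen3"])
    def test_last_token_logits_only(model_name):
        model, input_ids = build_model_and_inputs(model_name)  # hypothetical helper
        full = model(input_ids)                               # [B, T, V]
        last = model(input_ids, last_token_logits_only=True)  # [B, 1, V]
        assert last.shape == (full.shape[0], 1, full.shape[-1])
        # Last-token logits must match the final position of the full logits.
        torch.testing.assert_close(last[:, 0], full[:, -1])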

raulchen and others added 3 commits January 14, 2026 12:00
During prefill, only the last token's logits are needed to start
decoding. This optimization avoids materializing the full [B, T, V]
logits tensor when prompt_logprobs is not requested.

Add parametrized test in test_models_common.py that verifies
both llama3 and qwen3 models produce correct output shape and
matching logits when using last_token_logits_only=True. Also
tests generation equivalence with and without prompt_logprobs.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable optimization to skip the computation of full logits during the prefill stage when prompt log probabilities are not required. The changes are implemented by adding a last_token_logits_only parameter to the Llama3 and Qwen3 models and leveraging it in the generator's prefill logic. The implementation is clean and the logic is sound. A new parametrized test file is added, which thoroughly validates the new functionality by checking output shapes, value equivalence, and generation consistency. I have one suggestion to improve the clarity of the new test code.

Comment on lines 37 to 41
        base_config = (
            AutoConfig.from_pretrained(model_name)
            if config_cls == Llama3Config
            else PretrainedConfig.from_pretrained(model_name)
        )

Severity: medium

The conditional logic for loading the base configuration is a bit confusing and contains unreachable code. Both Llama3Config and Qwen3Config are aliases for ModelConfig, so the condition config_cls == Llama3Config will be true for both parameterized test cases, making the else branch dead code.

To simplify and improve clarity, you can consistently use AutoConfig.from_pretrained(model_name), which is the standard approach for loading configurations from the HuggingFace Hub.

        base_config = AutoConfig.from_pretrained(model_name)
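
For context, AutoConfig.from_pretrained dispatches to the appropriate config class based on the model_type field in the checkpoint's config.json, so a single call covers both the Llama and Qwen test cases.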

@raulchen
Contributor Author

/gemini review


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable optimization to skip computing full logits during prefill when they are not needed, controlled by a new last_token_logits_only parameter. The implementation is clean and correctly applied to both Llama3 and Qwen3 models. The logic is correctly wired into the generator utility, activating the optimization when prompt_logprobs is false. The addition of a new parameterized test in test_models_common.py is excellent, as it thoroughly verifies the correctness of the change by checking output shapes, value equivalence, and end-to-end generation consistency. I have one suggestion to improve the new test for better failure diagnostics.
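
For reference, the wiring described here amounts to something like the following in the prefill path (illustrative names, not the exact generator code):

    # Skip the full [B, T, V] projection unless prompt logprobs were requested.
    logits = model(
        prompt_tokens,
        last_token_logits_only=not sampling_params.prompt_logprobs,
    )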

raulchen and others added 2 commits January 14, 2026 14:44
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@pcmoritz added the tx label Jan 15, 2026