[Model Runner V2] Introduce num_tokens_for_attn#36815

Open
WoosukKwon wants to merge 1 commit into main from
woosuk/mrv2-cudagraph-attn-fix

Conversation

@WoosukKwon
Collaborator

No description provided.

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@mergify

mergify bot commented Mar 11, 2026

Hi @WoosukKwon, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new field, num_tokens_for_attn, to BatchExecutionDescriptor and InputBatch to allow for a more precise specification of the number of tokens used in attention mechanisms, particularly for CUDA graphs. The changes are consistently applied across cudagraph_utils, dp_utils, model_runner, and model_states to propagate and utilize this new field. The implementation appears correct and well-integrated.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new field, num_tokens_for_attn, to BatchExecutionDescriptor and InputBatch. This field is used to specify the exact number of tokens that should be processed by the attention mechanism, which can differ from the total number of tokens in a batch, particularly when using CUDA graphs for decode operations. The changes are consistently propagated through cudagraph_utils.py, dp_utils.py, input_batch.py, model_runner.py, and model_states/default.py. This refactoring centralizes the logic for determining the attention token count and simplifies the prepare_attn function. The implementation appears correct and robust, with no high or critical issues found.

Member

@njhill njhill left a comment


we probably wanna rename these: num_actual_tokens <= num_attn_tokens <= num_input_tokens

agree and let's document the above relationship explicitly in comment too :)
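The proposed relationship could be enforced with a runtime check along these lines. This is a minimal illustrative sketch, not the PR's actual code: the helper name and its placement are assumptions, only the three field names and the inequality come from the comment above.

```python
# Hedged sketch of the invariant discussed above:
#   num_actual_tokens <= num_attn_tokens <= num_input_tokens
# i.e. tokens actually scheduled <= tokens the attention kernels see
# <= the (possibly CUDA-graph-padded) input size.
def check_token_counts(
    num_actual_tokens: int,
    num_attn_tokens: int,
    num_input_tokens: int,
) -> None:
    # Fail loudly if any padding/slicing step violates the ordering.
    assert num_actual_tokens <= num_attn_tokens <= num_input_tokens, (
        f"expected {num_actual_tokens} <= {num_attn_tokens} "
        f"<= {num_input_tokens}"
    )


# e.g. 5 scheduled tokens, attention runs on 8, batch padded to a
# 16-token CUDA graph size.
check_token_counts(5, 8, 16)
```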

Comment on lines +141 to +142
# i.e. no request padding is needed
# so we leave it as None
Member


save a line?

Suggested change
# i.e. no request padding is needed
# so we leave it as None
# i.e. no request padding is needed, so we leave it as None

Comment on lines 36 to 39
num_tokens: int
num_tokens_for_attn: int | None
num_reqs: int | None # None means no request padding is needed (PIECEWISE graphs)
uniform_token_count: int | None = None
Member


I think we should add more doc to these fields

including meaning of None for the other ones too

Comment on lines +604 to +607
if batch_desc.num_tokens_for_attn is not None:
num_tokens_for_attn = batch_desc.num_tokens_for_attn
else:
num_tokens_for_attn = num_tokens
Member


could simplify

Suggested change
if batch_desc.num_tokens_for_attn is not None:
num_tokens_for_attn = batch_desc.num_tokens_for_attn
else:
num_tokens_for_attn = num_tokens
num_tokens_for_attn = batch_desc.num_tokens_for_attn or num_tokens

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 11, 2026
@njhill
Member

njhill commented Mar 11, 2026

Also would be good to make sure the CI tests cover this

