[Model Runner V2] Introduce num_tokens_for_attn #36815

WoosukKwon wants to merge 1 commit into `main`
Conversation
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Hi @WoosukKwon, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Code Review
This pull request introduces a new field, num_tokens_for_attn, to BatchExecutionDescriptor and InputBatch to allow for a more precise specification of the number of tokens used in attention mechanisms, particularly for CUDA graphs. The changes are consistently applied across cudagraph_utils, dp_utils, model_runner, and model_states to propagate and utilize this new field. The implementation appears correct and well-integrated.
Code Review
This pull request introduces a new field, num_tokens_for_attn, to BatchExecutionDescriptor and InputBatch. This field is used to specify the exact number of tokens that should be processed by the attention mechanism, which can differ from the total number of tokens in a batch, particularly when using CUDA graphs for decode operations. The changes are consistently propagated through cudagraph_utils.py, dp_utils.py, input_batch.py, model_runner.py, and model_states/default.py. This refactoring centralizes the logic for determining the attention token count and simplifies the prepare_attn function. The implementation appears correct and robust, with no high or critical issues found.
njhill left a comment:
we probably wanna rename these:
`num_actual_tokens <= num_attn_tokens <= num_input_tokens`
agree and let's document the above relationship explicitly in comment too :)
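The proposed relationship could be documented and enforced with a small assertion helper. This is only a sketch; the function and argument names are illustrative, not from the PR:

```python
def check_token_counts(num_actual_tokens: int,
                       num_attn_tokens: int,
                       num_input_tokens: int) -> None:
    """Sanity-check the invariant proposed in the review:

        num_actual_tokens <= num_attn_tokens <= num_input_tokens

    num_actual_tokens: real (unpadded) tokens in the batch.
    num_attn_tokens:   tokens the attention kernels must process.
    num_input_tokens:  padded count matching the captured CUDA graph size.
    """
    assert num_actual_tokens <= num_attn_tokens <= num_input_tokens, (
        f"token-count invariant violated: {num_actual_tokens=} "
        f"{num_attn_tokens=} {num_input_tokens=}"
    )
```

Making the invariant an assertion (rather than only a comment) would also give CI a cheap way to catch regressions.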
```python
# i.e. no request padding is needed
# so we leave it as None
```
save a line?
```suggestion
# i.e. no request padding is needed, so we leave it as None
```
```python
num_tokens: int
num_tokens_for_attn: int | None
num_reqs: int | None  # None means no request padding is needed (PIECEWISE graphs)
uniform_token_count: int | None = None
```
I think we should add more documentation to these fields, including the meaning of `None` for the other ones too.
```python
if batch_desc.num_tokens_for_attn is not None:
    num_tokens_for_attn = batch_desc.num_tokens_for_attn
else:
    num_tokens_for_attn = num_tokens
```
could simplify
```suggestion
num_tokens_for_attn = batch_desc.num_tokens_for_attn or num_tokens
```
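Worth noting that the `or` one-liner matches the explicit `is not None` check only as long as `num_tokens_for_attn` is never `0`, since `or` falls through on any falsy value. A quick standalone comparison (variable names taken from the diff, functions are illustrative):

```python
def explicit(num_tokens_for_attn, num_tokens):
    # Original form: falls back only when the value is None.
    if num_tokens_for_attn is not None:
        return num_tokens_for_attn
    return num_tokens


def with_or(num_tokens_for_attn, num_tokens):
    # Suggested one-liner: also falls back when the value is 0 (falsy).
    return num_tokens_for_attn or num_tokens


print(explicit(0, 128))    # -> 0
print(with_or(0, 128))     # -> 128  (differs when the value is 0)
print(with_or(None, 128))  # -> 128
```

If a zero attention-token count can never occur in practice, the shorter form is safe; otherwise the explicit check is the correct one.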
Also would be good to make sure the CI tests cover this.
No description provided.