Skip to content

Commit 72184bf

Browse files
committed
Fix Comment
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
1 parent d9bd46f commit 72184bf

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -675,11 +675,9 @@ def _general_warmup(self,
                 self.kv_cache_manager_key)
             token_num_upper_bound = min(self.max_num_tokens,
                                         self.batch_size * (self.max_seq_len - 1))
-            curr_max_num_tokens = min(
-                kv_cache_manager.get_num_available_tokens(
-                    token_num_upper_bound=token_num_upper_bound,
-                    max_num_draft_tokens=self.original_max_draft_len),
-                token_num_upper_bound)
+            curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
+                token_num_upper_bound=token_num_upper_bound,
+                max_num_draft_tokens=self.original_max_draft_len)
             max_batch_size = min(
                 self.batch_size,
                 curr_max_num_tokens // (1 + self.runtime_draft_len))
@@ -730,11 +728,9 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
                 self.kv_cache_manager_key)
             token_num_upper_bound = min(self.max_num_tokens,
                                         self.batch_size * (self.max_seq_len - 1))
-            curr_max_num_tokens = min(
-                kv_cache_manager.get_num_available_tokens(
-                    token_num_upper_bound=token_num_upper_bound,
-                    max_num_draft_tokens=self.original_max_draft_len),
-                token_num_upper_bound)
+            curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
+                token_num_upper_bound=token_num_upper_bound,
+                max_num_draft_tokens=self.original_max_draft_len)

             cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
             with self.no_cuda_graph(), autotune(cache_path=cache_path):

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,10 +1002,13 @@ def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

     def get_num_available_tokens(self,
+                                 token_num_upper_bound: int,
                                  max_num_draft_tokens: int = 0,
                                  **kwargs) -> int:
-        return (self.get_num_free_blocks() * self.tokens_per_block -
-                self.num_extra_kv_tokens - max_num_draft_tokens)
+        return min(
+            token_num_upper_bound,
+            self.get_num_free_blocks() * self.tokens_per_block -
+            self.num_extra_kv_tokens - max_num_draft_tokens)

     def get_buffers(self,
                     layer_idx: int,

0 commit comments

Comments
 (0)