
Commit 30fcc38

fix time logging and other small things (#590)
Addressing some feedback to #552:

- fix time logging
- remove unnecessary (duplicated) code
- fix the typing of the scheduler class
- fix some typos

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
1 parent a6e84e5 commit 30fcc38

File tree (3 files changed: +17 -18 lines)

- tests/scheduling_utils.py
- vllm_spyre/platform.py
- vllm_spyre/v1/worker/spyre_model_runner.py

tests/scheduling_utils.py

Lines changed: 5 additions & 3 deletions
@@ -1,7 +1,7 @@
 import copy
 import os
 from collections import defaultdict, deque
-from typing import Any
+from typing import Any, Union

 import pytest
 from llm_cache import get_cached_engine
@@ -13,7 +13,8 @@
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore

-from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler
+from vllm_spyre.v1.core.scheduler import (ChunkedPrefillSpyreScheduler,
+                                          ContinuousBatchingSpyreScheduler)

 DISABLE_ASSERTS = False  # used for debugging

@@ -162,7 +163,8 @@ def check_scheduler_inference_steps(
         available_blocks=available_blocks,
         backend=backend,
         monkeypatch=monkeypatch)
-    scheduler: ContinuousBatchingSpyreScheduler = engine_core.scheduler
+    scheduler: Union[ContinuousBatchingSpyreScheduler,
+                     ChunkedPrefillSpyreScheduler] = engine_core.scheduler

     tokenizer = get_tokenizer(model.name, revision=model.revision)
     # clear the cache of function scheduler.check_batch_tkv_limit()
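
The widened annotation reflects that engine_core.scheduler may hold either scheduler class once chunked prefill is in play. A minimal, self-contained sketch of the pattern; the classes and factory here are stand-ins for illustration, not the real vllm_spyre API:

    from typing import Union


    class ContinuousBatchingSpyreScheduler:
        """Stand-in for the real scheduler class."""


    class ChunkedPrefillSpyreScheduler:
        """Stand-in for the real scheduler class."""


    def make_scheduler(chunked_prefill: bool) -> Union[
            ContinuousBatchingSpyreScheduler, ChunkedPrefillSpyreScheduler]:
        # Hypothetical factory, only to show why the annotation is a Union.
        if chunked_prefill:
            return ChunkedPrefillSpyreScheduler()
        return ContinuousBatchingSpyreScheduler()


    # With the narrower annotation, a type checker would reject the
    # chunked-prefill case; the Union accepts both.
    scheduler: Union[ContinuousBatchingSpyreScheduler,
                     ChunkedPrefillSpyreScheduler] = make_scheduler(True)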

vllm_spyre/platform.py

Lines changed: 1 addition & 1 deletion
@@ -211,7 +211,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             f" be divisible by the block size ({cls._block_size}) "
             "to enable chunked prefill. It was set to "
             f"`{scheduler_config.max_num_batched_tokens}`. Please "
-            "set `--max-num-batched-tokens` to a number that satisfy "
+            "set `--max-num-batched-tokens` to a number that satisfies "
             "this constraint.")

         logger.info(
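
The reworded message belongs to a validation requiring --max-num-batched-tokens to be a multiple of the block size before chunked prefill is enabled. A stripped-down sketch of that check, using assumed example values rather than vLLM's actual config objects:

    block_size = 64                # assumed value for illustration
    max_num_batched_tokens = 2048  # would come from --max-num-batched-tokens

    if max_num_batched_tokens % block_size != 0:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be "
            f"divisible by the block size ({block_size}) to enable chunked "
            "prefill. Please set `--max-num-batched-tokens` to a number "
            "that satisfies this constraint.")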

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 11 additions & 14 deletions
@@ -1851,7 +1851,7 @@ def _prepare_chunked_prefill(self, req_id: str):

     # Case III

-    No left paddings and more than one chunk
+    No left padding and more than one chunk

     13 tokens
     4 blocks
@@ -1862,8 +1862,8 @@ def _prepare_chunked_prefill(self, req_id: str):

     NOTE: The goal of this "illustration" is to depics strategies to write
     code to create the chunks, not necessarily enumerate the possible
-    scenario. Of course there are interpretations where these cases
-    overlaps.
+    scenarios. Of course there are interpretations where these cases
+    overlap.

     '''


@@ -2105,7 +2105,7 @@ def add_new_request(self, request: NewRequestData):
         new_tokens = (sampling_params.max_tokens
                       if sampling_params is not None else 0)
         total_tokens = prompt_len + new_tokens - 1
-        # subtract the padding blocks from the reserved blocks
+        # calculate the number of reserved blocks
         n_reserved_blocks = math.ceil(total_tokens / self.block_size)

         self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks
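
The old comment described subtracting padding blocks, which did not match the code; the new one says what the line actually does: a ceiling division that reserves enough fixed-size KV-cache blocks for the prompt plus every token the request may generate (the -1 presumably because the final sampled token never needs a cache slot of its own). A worked example with assumed numbers; block_size=64 is illustrative, the real value comes from the platform config:

    import math

    block_size = 64   # assumed for illustration
    prompt_len = 100
    max_tokens = 29   # sampling_params.max_tokens

    total_tokens = prompt_len + max_tokens - 1           # 128
    n_reserved_blocks = math.ceil(total_tokens / block_size)
    print(n_reserved_blocks)                             # ceil(128 / 64) = 2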
@@ -2216,8 +2216,7 @@ def check_incomplete_prefill(self, scheduler_output: SchedulerOutput):
             return False

         # possible prefill
-        req_id = new_reqs[0].req_id if\
-            len(new_reqs) == 1 else \
+        req_id = new_reqs[0].req_id if len(new_reqs) == 1 else \
             cached_reqs.req_ids[0]

         num_scheduled_tokens =\
@@ -2302,6 +2301,9 @@ def execute_model(
         if not self.is_driver_worker:
             return self.get_empty_output()

+        t1 = time.time() - t0
+        logger.debug("t_forward_pass: %.2fms [prefill single chunk]",
+                     (t1 * 1000))
         return CPSpyreModelRunnerOutput(
             req_ids=list(req_id_to_index.keys()),
             req_id_to_index=req_id_to_index,
@@ -2319,19 +2321,14 @@ def execute_model(
             sampling_metadata=self.get_sampling_metadata(is_prefill),
         )
         t1 = time.time() - t0
-        logger.debug("t_token: %.2fms", (t1 * 1000))
-
-        # Add the sampled token(s) to the request cache
-        req_ids = ([r.req_id for r in scheduler_output.scheduled_new_reqs]
-                   if len(scheduler_output.scheduled_new_reqs) > 0 \
-                   else self.input_batch.sorted_requests_ids)
+        step_type = "[prefill last chunk]" if is_prefill else "[decode]"
+        logger.debug("t_token: %.2fms %s", (t1 * 1000), step_type)

         # Get the right batch, if this is the last chunk to conclude the
         # prefill, we'll generate a token and we should get from the prefill
         # batch because input_batch may have other request that are were
         # not processed at this step.
-        batch = self.prefill_batch if is_prefill \
-            else self.input_batch
+        batch = self.prefill_batch if is_prefill else self.input_batch

         # Add the sampled token(s) to the request cache
         req_ids = ([r.req_id for r in scheduler_output.scheduled_new_reqs]
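
The time-logging fix gives each timed step an explicit label: prefill chunks that finish without sampling a token appear to be logged as t_forward_pass, while steps that emit a token log t_token tagged as either [prefill last chunk] or [decode]. A self-contained sketch of that pattern; run_step and its flags are illustrative, not the runner's real interface:

    import logging
    import time

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("timing_sketch")


    def run_step(is_prefill: bool, produces_token: bool) -> None:
        t0 = time.time()
        # ... the model forward pass would run here ...
        t1 = time.time() - t0
        if is_prefill and not produces_token:
            # An intermediate prefill chunk: no token sampled yet.
            logger.debug("t_forward_pass: %.2fms [prefill single chunk]",
                         t1 * 1000)
        else:
            # The step emitted a token: tag the log line so prefill
            # completions and decode steps can be told apart.
            step_type = "[prefill last chunk]" if is_prefill else "[decode]"
            logger.debug("t_token: %.2fms %s", t1 * 1000, step_type)


    run_step(is_prefill=True, produces_token=False)  # t_forward_pass line
    run_step(is_prefill=True, produces_token=True)   # t_token [prefill last chunk]
    run_step(is_prefill=False, produces_token=True)  # t_token [decode]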
