Commit bf8c8ad

refactor(scheduler): delay caching instead of undoing (#525)
Problem: The scheduler speculatively cached blocks during `allocate_slots`, then had to undo the caching (via `undo_uncomputed_block_caching`) in three places: `spec_decode_cap` trimming, prefill over-allocation for chunked prefill, and prefills preempting running decodes. This was error-prone and coupled the scheduler to `KVCacheManager` internals (`block_pool`, `num_cached_block`).

Solution: Pass `delay_cache_blocks=True` to all `allocate_slots` calls so that no blocks are cached during allocation. A single finalization loop after all scheduling decisions calls `cache_blocks` and `schedule_sub_block_indexing` for each actually-scheduled request. This eliminates `undo_uncomputed_block_caching`.
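The before/after control flow described above can be sketched with toy stand-ins. The `Block` and `ToyKVCacheManager` classes below are illustrative only, not the actual vLLM-RBLN types; only the "allocate without caching, then cache exactly the computed blocks at finalization" pattern mirrors the commit.

```python
# Toy sketch of "delay caching, finalize once" vs. "cache eagerly, undo later".
# Block and ToyKVCacheManager are hypothetical stand-ins for illustration.

class Block:
    def __init__(self):
        self.block_hash = None  # None => not present in the prefix cache

class ToyKVCacheManager:
    def __init__(self):
        self.blocks = {}  # request_id -> list[Block]

    def allocate_slots(self, req_id, num_blocks, delay_cache_blocks=False):
        blocks = [Block() for _ in range(num_blocks)]
        self.blocks[req_id] = blocks
        if not delay_cache_blocks:
            # Eager path: speculatively cache everything now (may need undo).
            for i, b in enumerate(blocks):
                b.block_hash = (req_id, i)
        return blocks

    def cache_blocks(self, req_id, num_computed_blocks):
        # Finalization: cache only blocks whose KV data was actually computed.
        for i, b in enumerate(self.blocks[req_id][:num_computed_blocks]):
            b.block_hash = (req_id, i)

mgr = ToyKVCacheManager()
# Pre-allocate 4 blocks, but this step only computes 1 of them.
blocks = mgr.allocate_slots("req-0", 4, delay_cache_blocks=True)
assert all(b.block_hash is None for b in blocks)  # nothing cached yet

mgr.cache_blocks("req-0", num_computed_blocks=1)  # after scheduling finalizes
assert [b.block_hash is not None for b in blocks] == [True, False, False, False]
```

With the eager path, trimming the scheduled token count after allocation would leave stale hashes behind; the delayed path never creates them in the first place.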
1 parent 226bc60 commit bf8c8ad

File tree

5 files changed: +107 additions, -79 deletions


docs/sub_block_prefix_caching.md

Lines changed: 14 additions & 8 deletions
@@ -73,7 +73,9 @@ so that each prefill does not span multiple blocks.
 * `RBLNKVCacheManager`: Extends upstream `KVCacheManager`.
   * Overrides
     * `allocate_slots` queues the request for sub-block indexing work
-      to be processed by `do_pending_indexing`.
+      when `delay_cache_blocks=False`.
+      When `delay_cache_blocks=True`, the caller must call
+      `schedule_sub_block_indexing()` after `cache_blocks()`.
     * `free` indexes all blocks (full + partial) for the finishing request
       and assigns a synthetic `block_hash` to the partial block.
     * `reset_prefix_cache` clears sub-block indices and pending indexing.
@@ -83,8 +85,9 @@ so that each prefill does not span multiple blocks.
   * `apply_sub_block_match` / `release_sub_block_match` consume or discard the handle.
   * `drain_pending_copy_ops()` retrieves the KV cache copy ops accumulated in the current scheduling step.
   * `release_copy_ops()` releases the source-block references after the model runner finishes copying.
-  * `do_pending_indexing()` indexes sub-blocks for requests for which
-    `allocate_slots` was called in the current scheduling step.
+  * `schedule_sub_block_indexing(request)` records that a request
+    needs sub-block indexing.
+  * `do_pending_indexing()` executes the scheduled indexing work.
     Must be called after `super().update_from_output()`.
   * `can_use_sub_block_caching()` checks eligibility.
 * `KVCacheCopyOp`: Dataclass describing a sub-block KV data copy:
@@ -130,7 +133,12 @@ Sub-block indexing is **deferred** until after the forward pass writes KV data.
 This ensures that concurrent prefills in the same scheduling step cannot match
 sub-blocks whose KV data has not yet been computed and thus should not be copied.
 
-During `allocate_slots`, requests are queued for deferred indexing.
+When the scheduler schedules a request, it should schedule sub-block indexing for that request.
+This is done automatically when `allocate_slots` is called with `delay_cache_blocks=False`.
+If `delay_cache_blocks=True`, the caller must call `schedule_sub_block_indexing()` after upstream `cache_blocks()`.
+The current implementation of `RBLNScheduler` uses the latter approach,
+because its complex scheduling logic requires manual control over full block caching.
+
 `RBLNScheduler.update_from_output` first calls `super().update_from_output()`
 (which updates `num_computed_tokens` and `free()`s finished requests),
 then calls `do_pending_indexing` for the remaining running requests.
@@ -170,12 +178,10 @@ produce logits).
 ### Step 4: Copy op scheduling
 
 After `allocate_slots` succeeds, the scheduler calls
-`apply_sub_block_match(match, request)` which, for each group:
+`apply_sub_block_match(match)` which, for each group:
 1. Looks up the destination block (newly allocated at the match boundary)
 2. Appends `KVCacheCopyOp(group_id, src_block_id, dst_block_id, num_tokens)`
 
-(`allocate_slots` itself only queues deferred sub-block indexing.)
-
 ### Step 5: Copy execution (model runner)
 
 The scheduler returns `RBLNSchedulerOutput` containing `kv_cache_copy_ops`.
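The copy executed in Step 5 amounts to a per-group slice assignment between block slots. A NumPy sketch follows; the slicing pattern comes from the doc, but the tensor layout and concrete sizes here are assumptions for illustration, and the real runner operates on device tensors:

```python
import numpy as np

# Assumed layout: kv_cache[kv, block, layer, head, token, dim].
# Sizes are made up; only the slicing pattern mirrors the doc.
num_blocks, num_layers, num_heads, block_size, head_dim = 4, 2, 2, 16, 8
rng = np.random.default_rng(0)
kv_cache = rng.normal(
    size=(2, num_blocks, num_layers, num_heads, block_size, head_dim)
)

def run_copy_op(src_block_id: int, dst_block_id: int, num_tokens: int) -> None:
    # Copy the first num_tokens token slots of the matched source block
    # into the newly allocated destination block.
    kv_cache[:, dst_block_id, :, :, :num_tokens, :] = \
        kv_cache[:, src_block_id, :, :, :num_tokens, :]

run_copy_op(src_block_id=0, dst_block_id=3, num_tokens=4)
assert np.array_equal(kv_cache[:, 3, :, :, :4, :], kv_cache[:, 0, :, :, :4, :])
```

Token slots beyond `num_tokens` in the destination block are left untouched; they will be filled by the forward pass.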
@@ -191,7 +197,7 @@ kv_cache[:, dst_block_id, :, :, :num_tokens, :] = \
 ### Block lifecycle
 
 - **Indexing running requests**:
-  Scheduled by `allocate_slots`, then executed by `do_pending_indexing`
+  Scheduled by `allocate_slots` or by the caller, then executed by `do_pending_indexing`
   (called after `super().update_from_output()`).
   Indexes both full blocks and complete sub-blocks within partial blocks.
- **Indexing finished requests**: `free()` consumes the pending-indexing
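The schedule-then-index lifecycle above can be sketched with a minimal stand-in. `ToyIndexer` is hypothetical; only the two-phase contract (record during scheduling, index after the forward pass has written KV data) reflects the doc:

```python
# Toy sketch of the deferred sub-block indexing contract.
# ToyIndexer is a hypothetical stand-in, not the real RBLNKVCacheManager.

class ToyIndexer:
    def __init__(self):
        self._pending = {}            # request_id -> num_full_blocks_before
        self.sub_block_index = set()  # (request_id, block_idx) entries

    def schedule_sub_block_indexing(self, req_id, num_full_blocks_before):
        # Called during scheduling; records a snapshot, indexes nothing yet.
        self._pending[req_id] = num_full_blocks_before

    def do_pending_indexing(self, num_computed_blocks):
        # Called after update_from_output, i.e. after KV data was written,
        # so concurrent prefills can never match uncomputed sub-blocks.
        for req_id, before in self._pending.items():
            for i in range(before, num_computed_blocks[req_id]):
                self.sub_block_index.add((req_id, i))
        self._pending.clear()

idx = ToyIndexer()
idx.schedule_sub_block_indexing("req-0", num_full_blocks_before=1)
assert not idx.sub_block_index          # nothing indexed at schedule time

idx.do_pending_indexing({"req-0": 3})   # blocks 1 and 2 became computed
assert idx.sub_block_index == {("req-0", 1), ("req-0", 2)}
```

Because indexing only happens in the second phase, a request that was scheduled but never computed contributes nothing to the index.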

tests/torch_compile/unit/v1/core/test_prefix_caching.py

Lines changed: 40 additions & 0 deletions
@@ -138,6 +138,46 @@ def test_preallocation_in_prefill():
     )
 
 
+def test_chunked_prefill_caches_blocks_progressively():
+    # With delay_cache_blocks + finalization, verify that each prefill
+    # chunk caches exactly the computed blocks.
+    block_size = 16
+    num_blocks_per_request = 4
+
+    scheduler = create_scheduler(
+        block_size=block_size,
+        max_num_batched_tokens=block_size,  # 1 block per step
+        enable_prefix_caching=True,
+        max_model_len=num_blocks_per_request * block_size * 2,
+    )
+    mgr = scheduler.kv_cache_manager
+
+    req = create_requests(
+        1,
+        num_tokens=num_blocks_per_request * block_size,
+        max_tokens=1,
+        same_prompt=True,
+    )[0]
+    scheduler.add_request(req)
+
+    for step in range(num_blocks_per_request):
+        output = scheduler.schedule()
+
+        # All blocks are pre-allocated from step 0.
+        blocks = mgr.get_blocks(req.request_id).blocks[0]
+        assert len(blocks) == num_blocks_per_request
+
+        # After finalization: exactly (step + 1) blocks should be cached.
+        cached = [b for b in blocks if b.block_hash is not None]
+        uncached = [b for b in blocks if b.block_hash is None]
+        assert len(cached) == step + 1
+        assert len(uncached) == num_blocks_per_request - step - 1
+
+        is_last = step == num_blocks_per_request - 1
+        runner_out = create_runner_output(output, 0 if is_last else None)
+        scheduler.update_from_output(output, runner_out)
+
+
 def test_preallocation_in_decode():
     # test that block preallocation during the decode phase
     # does not break prefix caching functionality

tests/torch_compile/unit/v1/core/test_sub_block_prefix_caching.py

Lines changed: 6 additions & 5 deletions
@@ -1831,9 +1831,10 @@ def test_multi_turn_conversation(self):
         expected_scheduled = req2.num_tokens - num_computed
         assert output2.num_scheduled_tokens["turn2"] == expected_scheduled
 
-    def test_speculative_alloc_undo_cleans_sub_block_index(self):
-        """Speculative allocate_slots + undo_uncomputed_block_caching must not
-        leave stale sub-block index entries for blocks that were never computed.
+    def test_speculative_alloc_does_not_index_uncomputed_blocks(self):
+        """Pre-allocated but uncomputed blocks must not appear in the
+        sub-block index. With delay_cache_blocks=True, only blocks that
+        are explicitly cached in the finalization step get indexed.
         """
         BS = self.BLOCK_SIZE  # 16
         SBS = self.SUB_BLOCK_SIZE  # 4
@@ -1846,9 +1847,9 @@ def test_speculative_alloc_undo_cleans_sub_block_index(self):
         # 3 full blocks + 1 partial block.
         # The scheduler pre-allocates blocks for ALL tokens but only computes
         # one chunk (BS tokens) per iteration.
-        # After undo_uncomputed_block_caching:
+        # With delay_cache_blocks + finalization:
         #   block 0: computed and indexed
-        #   blocks 1-2: full but uncomputed and unindexed
+        #   blocks 1-2: full, never got cached/indexed
         #   block 3: partial, never got indexed
         tokens = list(range(3 * BS + SBS))
        req = _make_request("req", tokens, BS, max_tokens=1)

vllm_rbln/v1/core/rbln_kv_cache_manager.py

Lines changed: 20 additions & 13 deletions
@@ -469,10 +469,6 @@ def allocate_slots(
         delay_cache_blocks: bool = False,
         num_encoder_tokens: int = 0,
     ) -> KVCacheBlocks | None:
-        num_full_blocks_before = tuple(
-            request.num_computed_tokens // gi.block_size for gi in self._group_infos
-        )
-
         result = super().allocate_slots(
             request,
             num_new_tokens,
@@ -484,17 +480,28 @@ def allocate_slots(
             num_encoder_tokens,
         )
 
-        if result is not None:
-            # Defer sub-block indexing until after execute_model writes KV cache,
-            # so that concurrent prefills in the same step cannot match sub-blocks
-            # whose KV data does not yet exist.
-            self._pending_indexing[request.request_id] = (
-                request,
-                num_full_blocks_before,
-            )
+        if result is not None and not delay_cache_blocks:
+            # When delay_cache_blocks=True, the caller is responsible for
+            # calling schedule_sub_block_indexing() after cache_blocks().
+            self.schedule_sub_block_indexing(request)
 
         return result
 
+    def schedule_sub_block_indexing(self, request: Request) -> None:
+        """Record that *request* needs sub-block indexing in the next
+        ``do_pending_indexing`` call.
+
+        When ``allocate_slots`` is called with ``delay_cache_blocks=False``,
+        this is called automatically; otherwise the caller must call it explicitly.
+        """
+        num_full_blocks_before = tuple(
+            request.num_computed_tokens // gi.block_size for gi in self._group_infos
+        )
+        self._pending_indexing[request.request_id] = (
+            request,
+            num_full_blocks_before,
+        )
+
     def drain_pending_copy_ops(self) -> list[KVCacheCopyOp]:
         """Return and clear all pending copy operations.
@@ -579,7 +586,7 @@ def _get_or_compute_sub_hashes(self, request: Request) -> list[BlockHash]:
     def _index_newly_cached_blocks(
         self, request: Request, num_full_blocks_before: tuple[int, ...]
     ) -> None:
-        """After allocate_slots caches new full blocks, index their sub-blocks."""
+        """Index sub-blocks for newly cached full blocks since the last call."""
         blocks = self.coordinator.get_blocks(request.request_id)
         for gi, block_list, before in zip(
             self._group_infos, blocks, num_full_blocks_before

vllm_rbln/v1/core/rbln_scheduler.py

Lines changed: 27 additions & 53 deletions
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import time
 from dataclasses import dataclass, field
 
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.utils.hashing import get_hash_fn_by_name
-from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.kv_cache_utils import init_none_hash
 from vllm.v1.core.sched.interface import PauseState
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
@@ -52,30 +53,6 @@ def is_prefill(request: Request) -> bool:
     return request.num_computed_tokens < request.num_tokens - 1
 
 
-def undo_uncomputed_block_caching(
-    request: Request,
-    kv_cache_manager: KVCacheManager,
-    num_computed_tokens: int | None = None,
-) -> None:
-    grouped_blocks = kv_cache_manager.get_blocks(request.request_id).blocks
-    num_computed_blocks = [
-        (num_computed_tokens or request.num_computed_tokens)
-        // group.kv_cache_spec.block_size
-        for group in kv_cache_manager.kv_cache_config.kv_cache_groups
-    ]
-    for blocks, num_full_block in zip(grouped_blocks, num_computed_blocks):
-        for block in blocks[num_full_block:]:
-            # NOTE(RBLN): this function call efficiently resets
-            # the block hash and evicts the corresponding block from the cache.
-            kv_cache_manager.block_pool._maybe_evict_cached_block(block)
-
-    for manager in kv_cache_manager.coordinator.single_type_managers:
-        # NOTE(RBLN): SingleTypeKVCacheManager instances track the number of
-        # cached blocks of running requests in num_cached_block dictionary.
-        if request.request_id in manager.num_cached_block:
-            manager.num_cached_block[request.request_id] = num_full_block
-
-
 class RBLNScheduler(Scheduler):
     def __init__(
         self,
@@ -264,6 +241,8 @@ def schedule(self) -> RBLNSchedulerOutput:
                 request,
                 num_new_tokens,
                 num_lookahead_tokens=self.num_lookahead_tokens,
+                # NOTE(RBLN): Cache blocks only after scheduling is finalized.
+                delay_cache_blocks=True,
             )
 
             if new_blocks is not None:
@@ -390,14 +369,6 @@ def schedule(self) -> RBLNSchedulerOutput:
                     continue
                 new_n = spec_decode_cap
 
-                # Extra blocks were allocated for the original token count but
-                # are no longer needed. Invalidate their prefix cache hash so
-                # they are not reused incorrectly; the blocks remain allocated
-                # and will be reused when this request needs them in a future step.
-                undo_uncomputed_block_caching(
-                    req, self.kv_cache_manager, req.num_computed_tokens + new_n
-                )
-
                 token_budget += old_n - new_n
                 num_scheduled_tokens[req_id] = new_n
 
@@ -643,7 +614,8 @@ def schedule(self) -> RBLNSchedulerOutput:
                 new_computed_blocks=new_computed_blocks,
                 num_lookahead_tokens=effective_lookahead_tokens,
                 num_external_computed_tokens=num_external_computed_tokens,
-                delay_cache_blocks=load_kv_async,
+                # NOTE(RBLN): Cache blocks only after scheduling is finalized.
+                delay_cache_blocks=True,
                 num_encoder_tokens=num_encoder_tokens,
             )
 
@@ -662,18 +634,6 @@ def schedule(self) -> RBLNSchedulerOutput:
                 self.kv_cache_manager.apply_sub_block_match(sub_block_match)
                 sub_block_match = None
 
-            # NOTE(RBLN): By calling allocate_slots with
-            # request.num_tokens - num_computed_tokens instead of num_new_tokens,
-            # we pre-allocate slots for all tokens that this request will prefill.
-            # If allocated slots end up filling a block, the block hash would also
-            # would be written down. However, since this iteration may not actually
-            # compute all tokens, the block may not be fully computed. Therefore,
-            # if the block is not finalized in this iteration, we must clear the
-            # block hash and undo block caching.
-            undo_uncomputed_block_caching(
-                request, self.kv_cache_manager, num_computed_tokens + num_new_tokens
-            )
-
             # KVTransfer: the connector uses this info to determine
             # if a load is needed. Note that
             # This information is used to determine if a load is
@@ -763,20 +723,14 @@ def schedule(self) -> RBLNSchedulerOutput:
             # current step. In the next step (or after this request's prefill
             # completes if it cannot finish within a single step) this request will
             # be scheduled together with the other running requests in the decoding
-            # phase. We also clear the block hash written in previous allocate_slots
-            # and undo block caching because this request and its tokens will be
-            # scheduled again, and allocate_slots will be invoked once more and the
-            # logic that writes the block hash will run again. Without clearing it
-            # here, an assertion error would occur because a block hash would
-            # already exist.
+            # phase.
             for req in scheduled_running_reqs:
                 req_to_new_blocks.pop(req.request_id)
                 num_scheduled_tokens.pop(req.request_id)
                 req.spec_token_ids = scheduled_spec_decode_tokens.pop(
                     req.request_id, []
                 )
                 scheduled_encoder_inputs.pop(req.request_id, None)
-                undo_uncomputed_block_caching(req, self.kv_cache_manager)
 
             scheduled_running_reqs.clear()
             token_budget = prefill_token_budget
@@ -807,6 +761,26 @@ def schedule(self) -> RBLNSchedulerOutput:
             scheduled_running_reqs
         ) <= len(self.running)
 
+        # NOTE(RBLN): All allocate_slots calls above used delay_cache_blocks=True
+        # so that scheduling decisions (spec_decode_cap trimming, prefill kicking
+        # out running decodes) can adjust token counts without needing to undo
+        # premature caching. Now that scheduling is finalized, cache blocks and
+        # schedule sub-block indexing for all scheduled requests.
+        for req in itertools.chain(
+            scheduled_running_reqs, scheduled_new_reqs, scheduled_resumed_reqs
+        ):
+            self.kv_cache_manager.cache_blocks(
+                req,
+                # Cap at req.num_tokens to exclude unverified spec decode
+                # draft tokens, matching the upstream allocate_slots behavior.
+                min(
+                    req.num_computed_tokens + num_scheduled_tokens[req.request_id],
+                    req.num_tokens,
+                ),
+            )
+            if isinstance(self.kv_cache_manager, RBLNKVCacheManager):
+                self.kv_cache_manager.schedule_sub_block_indexing(req)
+
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
