Skip to content

Commit 575a316

Browse files
authored
fix(core): block allocation for torch compile path (#77)
* fix(v1): use reserved null block for padding
* fix(v0): match min block requirements with scheduling heuristics
* lint: format code
1 parent c31a7ae commit 575a316

3 files changed

Lines changed: 13 additions & 11 deletions

File tree

vllm_rbln/v1/attention/backends/flash_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def build(
369369
block_table_tensor,
370370
torch.full(
371371
(batch_padding_size, block_table_tensor.shape[-1]),
372-
block_table_tensor.numel() - 1,
372+
0,
373373
),
374374
])
375375
decode_attention_mask = torch.zeros(

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,10 +1581,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
15811581
self.initialize_attn_backend(kv_cache_config)
15821582
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
15831583

1584-
# for partition skip, we need dummy block slot.
1585-
no_dummy_slots = 1
1586-
kv_cache_config.num_blocks -= no_dummy_slots
1587-
15881584
if self.speculative_config and self.speculative_config.use_eagle():
15891585
assert isinstance(self.drafter, EagleProposer)
15901586
# validate all draft model layers belong to the same kv cache

vllm_rbln/worker/worker.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,10 @@ def __init__(
9999

100100
def _allocate_kv_cache(self, ) -> List[torch.Tensor]:
101101
"""Allocates KV cache on RBLN."""
102+
103+
# One extra block is reserved for padding.
102104
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
103-
self.num_cpu_blocks, self.block_size, self.num_heads,
105+
self.num_cpu_blocks + 1, self.block_size, self.num_heads,
104106
self.head_size)
105107
kv_cache: List[torch.Tensor] = []
106108
logger.info("[RBLN] attention backend get_kv_cache_shape = %s",
@@ -303,12 +305,16 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
303305
# 1 : prefill
304306
num_runtimes=1 + self.scheduler_config.max_num_seqs)
305307

306-
max_required_num_blocks = (self.model_config.max_model_len *
307-
self.scheduler_config.max_num_seqs //
308-
block_size)
309-
310-
num_gpu_blocks = min(max_num_blocks - 1, max_required_num_blocks)
308+
max_required_num_blocks = (
309+
self.model_config.max_model_len *
310+
self.scheduler_config.max_num_seqs //
311+
block_size) + self.scheduler_config.max_num_seqs + 1
311312

313+
# We always allocate this number of blocks, but the last one is
314+
# reserved for padding. As a result, the vLLM system should treat
315+
# it as if there is one fewer usable block than the number
316+
# actually allocated.
317+
num_gpu_blocks = min(max_num_blocks, max_required_num_blocks) - 1
312318
if npu_num_blocks := os.environ.get("VLLM_RBLN_NPU_NUM_BLOCKS"):
313319
num_gpu_blocks = int(npu_num_blocks) - 1
314320

0 commit comments

Comments (0)