Skip to content

Commit 0a98aba

Browse files
authored
fix: set num_blocks outside of model_runner (#410)
1 parent 6546c11 commit 0a98aba

3 files changed

Lines changed: 76 additions & 7 deletions

File tree

vllm_rbln/model_executor/models/optimum/model_base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,13 @@ def _env_int(name: str, default: int) -> int:
6767

6868
def _estimated_num_blocks(self) -> int:
6969
"""Override estimated blocks if num_gpu_blocks_override is set."""
70-
num_gpu_blocks_override = self.vllm_config.cache_config.num_gpu_blocks_override
71-
if num_gpu_blocks_override is not None:
70+
if (
71+
self.vllm_config.additional_config
72+
and "num_blocks_override" in self.vllm_config.additional_config
73+
):
74+
num_gpu_blocks_override = self.vllm_config.additional_config[
75+
"num_blocks_override"
76+
]
7277
return num_gpu_blocks_override
7378
else:
7479
return int(self.estimated_kvcache_num_blocks)

vllm_rbln/utils/optimum/configuration.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16+
import math
1617
import os
1718
from pathlib import Path
1819
from typing import TYPE_CHECKING
@@ -32,9 +33,34 @@
3233
logger = init_logger(__name__)
3334

3435

36+
def is_full_block_available(num_blocks: int, vllm_config: "VllmConfig") -> bool:
    """Report whether ``num_blocks`` covers the worst-case block demand.

    The ideal pool size is enough blocks for every sequence slot
    (``max_num_seqs``) to grow to the full model length at once.
    """
    cache_cfg = vllm_config.cache_config
    # With prefix caching the platform configures a (larger) attention
    # block size in additional_config; otherwise the plain cache block
    # size applies.
    if cache_cfg.enable_prefix_caching:
        block_size = vllm_config.additional_config["attn_block_size"]
    else:
        block_size = cache_cfg.block_size

    blocks_per_seq = math.ceil(
        vllm_config.model_config.max_model_len / block_size
    )
    ideal_total = vllm_config.scheduler_config.max_num_seqs * blocks_per_seq
    return num_blocks >= ideal_total
49+
50+
51+
def get_block_ratio(vllm_config: "VllmConfig") -> int:
    """Return the outer/inner block-size ratio (1 when prefix caching is off)."""
    if not vllm_config.cache_config.enable_prefix_caching:
        return 1
    # Outer (attention) block size divided by the inner cache block size;
    # integer division, matching the original computation.
    outer_size = vllm_config.additional_config["attn_block_size"]
    inner_size = vllm_config.cache_config.block_size
    return outer_size // inner_size
59+
60+
3561
def get_rbln_params(
3662
vllm_config: VllmConfig, rbln_config: dict
37-
) -> tuple[int, int, int, int]:
63+
) -> tuple[int, int, int, int, int]:
3864
kvcache_block_size = None
3965
prefill_chunk_size = 128
4066
batch_size = None
@@ -44,11 +70,13 @@ def get_rbln_params(
4470
max_seq_len = rbln_config.get("dec_max_seq_len")
4571
kvcache_block_size = max_seq_len
4672
batch_size = rbln_config.get("batch_size")
73+
num_blocks = rbln_config.get("kvcache_num_blocks")
4774
elif is_multi_modal(vllm_config.model_config.hf_config):
4875
# Get configurations from main module (e.g. Qwen2.5-VL, Whisper)
4976
kvcache_block_size = rbln_config.get("kvcache_block_size")
5077
batch_size = rbln_config.get("batch_size")
5178
max_seq_len = rbln_config.get("max_seq_len")
79+
num_blocks = rbln_config.get("kvcache_num_blocks")
5280
if max_seq_len is None: # Whisper FIXME to be moved to enc-dec
5381
max_seq_len = rbln_config.get("dec_max_seq_len")
5482
# Get configurations from submodule
@@ -61,19 +89,26 @@ def get_rbln_params(
6189
)
6290
batch_size = rbln_config[submodule].get("batch_size", None)
6391
max_seq_len = rbln_config[submodule].get("max_seq_len", None)
92+
num_blocks = rbln_config[submodule].get("kvcache_num_blocks", None)
6493
if kvcache_block_size is not None:
6594
break
6695

6796
elif is_pooling_arch(vllm_config.model_config.hf_config):
6897
max_seq_len = rbln_config.get("max_seq_len")
6998
kvcache_block_size = max_seq_len
7099
batch_size = rbln_config.get("batch_size")
100+
num_blocks = rbln_config.get("kvcache_num_blocks")
101+
if num_blocks is None:
102+
num_blocks = batch_size # for pooling models, each sequence is one block
71103
else:
72104
# decoder
73105
kvcache_block_size = rbln_config.get("kvcache_block_size")
74106
prefill_chunk_size = rbln_config.get("prefill_chunk_size", 128)
75107
batch_size = rbln_config.get("batch_size")
76108
max_seq_len = rbln_config.get("max_seq_len")
109+
num_blocks = rbln_config.get("kvcache_num_blocks")
110+
111+
assert num_blocks is not None, "num_blocks must be specified in rbln_config.json"
77112

78113
assert kvcache_block_size is not None, (
79114
"kvcache_block_size must be specified in rbln_config.json"
@@ -83,7 +118,7 @@ def get_rbln_params(
83118
# NOTE:
84119
# prefill_chunk_size is only used for decoder-only models
85120
# with prefix caching
86-
return kvcache_block_size, batch_size, max_seq_len, prefill_chunk_size
121+
return num_blocks, batch_size, max_seq_len, kvcache_block_size, prefill_chunk_size
87122

88123

89124
def set_block_size_for_prefix_caching(
@@ -132,6 +167,7 @@ def set_block_size_for_prefix_caching(
132167

133168
def update_vllm_config_with_rbln_params(
134169
vllm_config: VllmConfig,
170+
num_blocks: int,
135171
batch_size: int,
136172
max_model_len: int,
137173
kvcache_block_size: int,
@@ -181,6 +217,23 @@ def update_vllm_config_with_rbln_params(
181217
)
182218
vllm_config.cache_config.block_size = kvcache_block_size
183219

220+
# num_blocks is determined by rbln_config or overridden by user.
221+
if vllm_config.cache_config.num_gpu_blocks_override is not None:
222+
num_blocks = vllm_config.cache_config.num_gpu_blocks_override
223+
vllm_config.additional_config["num_blocks_override"] = num_blocks
224+
225+
blk_ratio = get_block_ratio(vllm_config)
226+
227+
if is_full_block_available(num_blocks, vllm_config):
228+
adjusted_num_blocks = num_blocks * blk_ratio + 1
229+
else:
230+
adjusted_num_blocks = (num_blocks - 1) * blk_ratio + 1
231+
232+
vllm_config.cache_config.num_blocks = adjusted_num_blocks
233+
234+
if vllm_config.cache_config.num_gpu_blocks_override is not None:
235+
vllm_config.cache_config.num_gpu_blocks_override = adjusted_num_blocks
236+
184237

185238
def is_qwen3_pooling(
186239
vllm_config: VllmConfig,
@@ -215,11 +268,16 @@ def sync_with_rbln_config(vllm_config: VllmConfig) -> None:
215268
raise RuntimeError("Failed to get RBLN config: %s", e) from e
216269

217270
if rbln_config is not None:
218-
kvcache_block_size, batch_size, max_model_len, prefill_chunk_size = (
219-
get_rbln_params(vllm_config, rbln_config)
220-
)
271+
(
272+
num_blocks,
273+
batch_size,
274+
max_model_len,
275+
kvcache_block_size,
276+
prefill_chunk_size,
277+
) = get_rbln_params(vllm_config, rbln_config)
221278
update_vllm_config_with_rbln_params(
222279
vllm_config,
280+
num_blocks,
223281
batch_size,
224282
max_model_len,
225283
kvcache_block_size,

vllm_rbln/v1/worker/optimum_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ def determine_available_memory(self) -> int:
112112

113113
adapter = self.model_runner.model.kv_block_adapter
114114
num_gpu_blocks = adapter.get_available_num_blocks()
115+
validation_blocks = self.model_runner.vllm_config.cache_config.num_blocks
116+
# This will be removed after validation check
117+
assert num_gpu_blocks == validation_blocks, (
118+
f"The number of blocks from the model runner ({num_gpu_blocks}) "
119+
f"and the platform ({validation_blocks}) must be the same."
120+
)
115121

116122
return num_gpu_blocks * page_size * num_layers
117123

0 commit comments

Comments (0)