1313# limitations under the License.
1414
1515import json
16+ import math
1617import os
1718from pathlib import Path
1819from typing import TYPE_CHECKING
3233logger = init_logger (__name__ )
3334
3435
def is_full_block_available(num_blocks: int, vllm_config: VllmConfig) -> bool:
    """Return True if ``num_blocks`` covers the worst-case KV-cache demand.

    The worst case is every sequence slot (``max_num_seqs``) filled with a
    sequence of ``max_model_len`` tokens, each needing ``ceil(len / block)``
    cache blocks.

    Args:
        num_blocks: Number of KV-cache blocks actually available.
        vllm_config: Engine configuration providing block size, model length
            and scheduler limits.

    Returns:
        True when ``num_blocks`` is at least the worst-case total.
    """
    cache_cfg = vllm_config.cache_config
    # With prefix caching the attention block size comes from the
    # platform-specific additional_config rather than cache_config.
    if cache_cfg.enable_prefix_caching:
        block_size = vllm_config.additional_config["attn_block_size"]
    else:
        block_size = cache_cfg.block_size

    seq_blocks = math.ceil(vllm_config.model_config.max_model_len / block_size)
    worst_case_total = vllm_config.scheduler_config.max_num_seqs * seq_blocks
    return num_blocks >= worst_case_total
49+
50+
def get_block_ratio(vllm_config: VllmConfig) -> int:
    """Return the outer-to-inner KV-cache block size ratio.

    When prefix caching is enabled, the attention (outer) block size from
    ``additional_config["attn_block_size"]`` is an integer multiple of the
    cache (inner) block size; otherwise the ratio is trivially 1.

    Args:
        vllm_config: Engine configuration to read block sizes from.

    Returns:
        Integer ratio ``attn_block_size // cache_block_size`` (1 when prefix
        caching is disabled).
    """
    if not vllm_config.cache_config.enable_prefix_caching:
        return 1
    outer_block = vllm_config.additional_config["attn_block_size"]
    inner_block = vllm_config.cache_config.block_size
    return outer_block // inner_block
59+
60+
3561def get_rbln_params (
3662 vllm_config : VllmConfig , rbln_config : dict
37- ) -> tuple [int , int , int , int ]:
63+ ) -> tuple [int , int , int , int , int ]:
3864 kvcache_block_size = None
3965 prefill_chunk_size = 128
4066 batch_size = None
@@ -44,11 +70,13 @@ def get_rbln_params(
4470 max_seq_len = rbln_config .get ("dec_max_seq_len" )
4571 kvcache_block_size = max_seq_len
4672 batch_size = rbln_config .get ("batch_size" )
73+ num_blocks = rbln_config .get ("kvcache_num_blocks" )
4774 elif is_multi_modal (vllm_config .model_config .hf_config ):
4875 # Get configurations from main module (e.g. Qwen2.5-VL, Whisper)
4976 kvcache_block_size = rbln_config .get ("kvcache_block_size" )
5077 batch_size = rbln_config .get ("batch_size" )
5178 max_seq_len = rbln_config .get ("max_seq_len" )
79+ num_blocks = rbln_config .get ("kvcache_num_blocks" )
5280 if max_seq_len is None : # Whisper FIXME to be moved to enc-dec
5381 max_seq_len = rbln_config .get ("dec_max_seq_len" )
5482 # Get configurations from submodule
@@ -61,19 +89,26 @@ def get_rbln_params(
6189 )
6290 batch_size = rbln_config [submodule ].get ("batch_size" , None )
6391 max_seq_len = rbln_config [submodule ].get ("max_seq_len" , None )
92+ num_blocks = rbln_config [submodule ].get ("kvcache_num_blocks" , None )
6493 if kvcache_block_size is not None :
6594 break
6695
6796 elif is_pooling_arch (vllm_config .model_config .hf_config ):
6897 max_seq_len = rbln_config .get ("max_seq_len" )
6998 kvcache_block_size = max_seq_len
7099 batch_size = rbln_config .get ("batch_size" )
100+ num_blocks = rbln_config .get ("kvcache_num_blocks" )
101+ if num_blocks is None :
102+ num_blocks = batch_size # for pooling models, each sequence is one block
71103 else :
72104 # decoder
73105 kvcache_block_size = rbln_config .get ("kvcache_block_size" )
74106 prefill_chunk_size = rbln_config .get ("prefill_chunk_size" , 128 )
75107 batch_size = rbln_config .get ("batch_size" )
76108 max_seq_len = rbln_config .get ("max_seq_len" )
109+ num_blocks = rbln_config .get ("kvcache_num_blocks" )
110+
111+ assert num_blocks is not None , "num_blocks must be specified in rbln_config.json"
77112
78113 assert kvcache_block_size is not None , (
79114 "kvcache_block_size must be specified in rbln_config.json"
@@ -83,7 +118,7 @@ def get_rbln_params(
83118 # NOTE:
84119 # prefill_chunk_size is only used for decoder-only models
85120 # with prefix caching
86- return kvcache_block_size , batch_size , max_seq_len , prefill_chunk_size
121+ return num_blocks , batch_size , max_seq_len , kvcache_block_size , prefill_chunk_size
87122
88123
89124def set_block_size_for_prefix_caching (
@@ -132,6 +167,7 @@ def set_block_size_for_prefix_caching(
132167
133168def update_vllm_config_with_rbln_params (
134169 vllm_config : VllmConfig ,
170+ num_blocks : int ,
135171 batch_size : int ,
136172 max_model_len : int ,
137173 kvcache_block_size : int ,
@@ -181,6 +217,23 @@ def update_vllm_config_with_rbln_params(
181217 )
182218 vllm_config .cache_config .block_size = kvcache_block_size
183219
220+ # num_blocks is determined by rbln_config or overridden by user.
221+ if vllm_config .cache_config .num_gpu_blocks_override is not None :
222+ num_blocks = vllm_config .cache_config .num_gpu_blocks_override
223+ vllm_config .additional_config ["num_blocks_override" ] = num_blocks
224+
225+ blk_ratio = get_block_ratio (vllm_config )
226+
227+ if is_full_block_available (num_blocks , vllm_config ):
228+ adjusted_num_blocks = num_blocks * blk_ratio + 1
229+ else :
230+ adjusted_num_blocks = (num_blocks - 1 ) * blk_ratio + 1
231+
232+ vllm_config .cache_config .num_blocks = adjusted_num_blocks
233+
234+ if vllm_config .cache_config .num_gpu_blocks_override is not None :
235+ vllm_config .cache_config .num_gpu_blocks_override = adjusted_num_blocks
236+
184237
185238def is_qwen3_pooling (
186239 vllm_config : VllmConfig ,
@@ -215,11 +268,16 @@ def sync_with_rbln_config(vllm_config: VllmConfig) -> None:
215268 raise RuntimeError ("Failed to get RBLN config: %s" , e ) from e
216269
217270 if rbln_config is not None :
218- kvcache_block_size , batch_size , max_model_len , prefill_chunk_size = (
219- get_rbln_params (vllm_config , rbln_config )
220- )
271+ (
272+ num_blocks ,
273+ batch_size ,
274+ max_model_len ,
275+ kvcache_block_size ,
276+ prefill_chunk_size ,
277+ ) = get_rbln_params (vllm_config , rbln_config )
221278 update_vllm_config_with_rbln_params (
222279 vllm_config ,
280+ num_blocks ,
223281 batch_size ,
224282 max_model_len ,
225283 kvcache_block_size ,
0 commit comments