Skip to content

Commit 3dc044a

Browse files
authored
feature: compile optimum model in vLLM if not present (#384)
1 parent 3a6d56c commit 3dc044a

18 files changed

Lines changed: 815 additions & 216 deletions

File tree

tests/v1/core/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import patch
16+
17+
import pytest
18+
19+
20+
@pytest.fixture(autouse=True)
def skip_prepare_compile():
    """Stub out the vLLM compile-preparation hook for every test in this package.

    Patching it prevents tests from accidentally triggering a real on-device
    model compilation, which is slow and requires RBLN hardware.
    """
    target = "vllm_rbln.utils.optimum.configuration.prepare_vllm_for_compile"
    with patch(target):
        yield

tests/v1/worker/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import shutil
17+
from unittest.mock import patch
1718

1819
import pytest
1920
import torch
@@ -28,3 +29,9 @@ def fresh_inductor_cache_per_test(monkeypatch):
2829
torch._dynamo.reset()
2930

3031
yield
32+
33+
34+
@pytest.fixture(autouse=True)
def skip_prepare_compile():
    """Stub out the vLLM compile-preparation hook for every test in this package.

    Patching it prevents tests from accidentally triggering a real on-device
    model compilation, which is slow and requires RBLN hardware.
    """
    target = "vllm_rbln.utils.optimum.configuration.prepare_vllm_for_compile"
    with patch(target):
        yield

vllm_rbln/model_executor/models/optimum/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
def load_model(vllm_config: VllmConfig) -> nn.Module:
5252
model_config = vllm_config.model_config
53-
53+
logger.info("Loading RBLN model from %s", model_config.model)
5454
if is_multi_modal(model_config.hf_config):
5555
assert vllm_config.cache_config.enable_prefix_caching in (False, None), (
5656
"Prefix caching is not supported with multimodal models. "

vllm_rbln/model_executor/models/optimum/model_base.py

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,37 @@
2525
from vllm.v1.sample.metadata import SamplingMetadata
2626

2727
import optimum.rbln
28+
import vllm_rbln.rbln_envs as envs
2829
from optimum.rbln.transformers.models.decoderonly import (
2930
decoderonly_runtime_utils as runtime_utils,
3031
)
3132
from vllm_rbln.utils.optimum.common import select_bucket_size
32-
from vllm_rbln.utils.optimum.registry import get_rbln_model_info
33+
from vllm_rbln.utils.optimum.registry import compile_model, get_rbln_model_info
3334

3435
logger = init_logger(__name__)
3536

3637

38+
def get_attn_block_size(vllm_config: VllmConfig) -> int:
    """Return the effective attention block size for this configuration.

    With prefix caching enabled the attention block size is configured
    separately (under ``additional_config["attn_block_size"]``); otherwise
    it coincides with the KV-cache block size from the cache config.
    """
    if vllm_config.cache_config.enable_prefix_caching:
        return vllm_config.additional_config["attn_block_size"]
    return vllm_config.cache_config.block_size
44+
45+
46+
def generate_model_path_name(
    model_name: str,
    batch_size: int,
    block_size: int,
    max_model_len: int,
    tp_size: int,
) -> str:
    """Build a filesystem-safe cache-directory name for a compiled model.

    The name encodes every compile-time knob that changes the compiled
    artifact (batch size, block size, max sequence length, TP degree) so
    that different configurations land in different cache entries.
    """
    # FIXME: To avoid cache collisions, the cache key should also include
    # the versions of the compiler and optimum-rbln.
    # "/" and ":" appear in HF repo ids and revisions; both are unsafe in
    # a single path component, so map them to "_".
    safe_name = model_name.translate(str.maketrans("/:", "__"))
    suffix = f"bs{batch_size}_blk{block_size}_msl{max_model_len}_tp{tp_size}"
    return f"{safe_name}_{suffix}"
57+
58+
3759
class KVCacheBlockAdapter:
3860
"""
3961
KV cache block allocation behavior (v1 vs v0).
@@ -81,12 +103,7 @@ def _estimated_num_blocks(self) -> int:
81103
def is_full_block_available(self) -> bool:
82104
"""True if we can allocate a full batch worth of blocks."""
83105
estimated = self._estimated_num_blocks()
84-
85-
if self.vllm_config.cache_config.enable_prefix_caching:
86-
block_size = self.vllm_config.additional_config["attn_block_size"]
87-
88-
else:
89-
block_size = self.vllm_config.cache_config.block_size
106+
block_size = get_attn_block_size(self.vllm_config)
90107

91108
max_model_len = self.vllm_config.model_config.max_model_len
92109
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
@@ -145,39 +162,76 @@ def _resolve_kvcache_num_blocks(self) -> int:
145162
return int(self.scheduler_config.max_num_seqs)
146163

147164
def init_model(self) -> None:
165+
# Check if the model is already compiled and load it;
166+
# else compile the model and load it.
148167
config = self.model_config.hf_config
149-
model_name, model_cls_name = get_rbln_model_info(config)
150-
151168
if isinstance(self.model_config.model, str | Path) and os.path.exists(
152169
self.model_config.model
153170
):
154171
model_path = Path(self.model_config.model)
155172
if model_path.is_dir() and any(model_path.glob("rbln_config.json")):
156-
compiled_path = self.model_config.model
173+
is_compiled_model = True
157174
else:
158-
compiled_path = None
175+
is_compiled_model = False
159176
else:
160-
compiled_path = None
177+
is_compiled_model = False
161178

162-
if compiled_path is None or not os.path.exists(compiled_path):
163-
raise RuntimeError(f"Compiled model path does not exist: {compiled_path}")
164-
165-
# huggingface model class name
166-
logger.info(
167-
"model_name = %s, model_cls_name = %s, model_path = %s",
168-
model_name,
169-
model_cls_name,
170-
compiled_path,
171-
)
179+
model_name, model_cls_name = get_rbln_model_info(config)
180+
model = None
181+
182+
# If a HuggingFace model (not optimum-compiled) is given,
183+
# look up the cached compiled model.
184+
# If it does not exist, compile and save it to the cache for future use.
185+
if not is_compiled_model:
186+
model_path_name = generate_model_path_name(
187+
self.model_config.model,
188+
batch_size=self.scheduler_config.max_num_seqs,
189+
block_size=get_attn_block_size(self.vllm_config),
190+
max_model_len=self.model_config.max_model_len,
191+
tp_size=envs.VLLM_RBLN_TP_SIZE,
192+
)
193+
cached_model_path = os.path.join(
194+
envs.VLLM_CACHE_ROOT,
195+
"compiled_models/" + model_path_name,
196+
)
197+
if not os.path.exists(cached_model_path):
198+
logger.info(
199+
"Compiling the model %s. This may take a while...",
200+
self.model_config.model,
201+
)
202+
model = compile_model(
203+
self.model_config.model,
204+
config,
205+
batch_size=self.scheduler_config.max_num_seqs,
206+
block_size=get_attn_block_size(self.vllm_config),
207+
max_model_len=self.model_config.max_model_len,
208+
tp_size=envs.VLLM_RBLN_TP_SIZE,
209+
model_path=str(cached_model_path),
210+
)
211+
else:
212+
logger.info(
213+
"Found compiled model at %s. Loading the model from the path.",
214+
cached_model_path,
215+
)
216+
self.vllm_config.model_config.model = cached_model_path
217+
218+
# Load the model directly if it is either an optimum-compiled model
219+
# or a HuggingFace model that has already been compiled and cached.
220+
if model is None:
221+
model_cls = getattr(optimum.rbln, model_cls_name)
222+
assert model_cls is not None
223+
model = model_cls.from_pretrained(self.vllm_config.model_config.model)
224+
logger.info(
225+
"model_name = %s, model_cls_name = %s, model_path = %s",
226+
model_name,
227+
model_cls_name,
228+
self.vllm_config.model_config.model,
229+
)
172230

173231
self.supports_transcription_only = (
174232
model_cls_name == "RBLNOptimumWhisperForConditionalGeneration"
175233
)
176234

177-
# huggingface model class
178-
model_cls = getattr(optimum.rbln, model_cls_name)
179-
assert model_cls is not None
180-
model = model_cls.from_pretrained(compiled_path, export=False)
181235
self.model = model
182236
self.rbln_model_config = model.rbln_config
183237
self.attn_impl = (

vllm_rbln/platform.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
239239
)
240240

241241
assert vllm_config.parallel_config.tensor_parallel_size == 1, (
242-
"Tensor parallelism is set when compiled in optimum-rbln."
242+
"Cannot set tensor_parallel_size for pre-compiled optimum-rbln models. "
243+
"If you want to compile with tensor parallelism in vllm-rbln, "
244+
"please use the `VLLM_RBLN_TP_SIZE` environment variable instead."
243245
)
244246
assert vllm_config.parallel_config.pipeline_parallel_size == 1, (
245247
"Pipeline parallelism is not supported in optimum-rbln."
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""KV-cache block calculation and synchronisation helpers."""
16+
17+
import math
18+
from typing import TYPE_CHECKING
19+
20+
if TYPE_CHECKING:
21+
from vllm.config import VllmConfig
22+
else:
23+
VllmConfig = None
24+
25+
from vllm_rbln.logger import init_logger
26+
27+
logger = init_logger(__name__)
28+
29+
30+
def is_full_block_available(num_blocks: int, vllm_config: VllmConfig) -> bool:
    """Return True when *num_blocks* can hold a full batch of max-length seqs.

    The block size used for the estimate is the attention block size when
    prefix caching is enabled, the plain KV-cache block size otherwise.
    """
    cache_config = vllm_config.cache_config
    if cache_config.enable_prefix_caching:
        block_size = vllm_config.additional_config["attn_block_size"]
    else:
        block_size = cache_config.block_size

    # Blocks needed for one sequence at the model's maximum length,
    # multiplied by the maximum concurrent batch size.
    per_seq = math.ceil(vllm_config.model_config.max_model_len / block_size)
    required = vllm_config.scheduler_config.max_num_seqs * per_seq
    return num_blocks >= required
43+
44+
45+
def get_block_ratio(vllm_config: VllmConfig) -> int:
    """Return the outer/inner block-size ratio used for prefix caching.

    Without prefix caching there is a single block granularity, so the
    ratio is 1. With prefix caching, the attention ("outer") block size is
    an integer multiple of the cache ("inner") block size.
    """
    if not vllm_config.cache_config.enable_prefix_caching:
        return 1
    outer = vllm_config.additional_config["attn_block_size"]
    inner = vllm_config.cache_config.block_size
    return outer // inner
53+
54+
55+
def apply_prefix_caching_block_size(
    vllm_config: VllmConfig, kvcache_block_size: int, prefill_chunk_size: int
) -> None:
    """Configure block sizes for prefix caching and validate divisibility.

    Chooses the prefix block size (user-provided via
    ``additional_config["prefix_block_size"]``, else ``prefill_chunk_size``),
    validates that the sizes nest evenly, then writes the result back:
    ``cache_config.block_size`` becomes the prefix block size and
    ``additional_config["attn_block_size"]`` becomes ``kvcache_block_size``.

    Raises:
        ValueError: if a user-provided prefix_block_size is not a multiple
            of prefill_chunk_size, exceeds kvcache_block_size, or does not
            divide kvcache_block_size evenly.
    """
    assert prefill_chunk_size is not None, (
        "prefill_chunk_size must be specified in rbln_config.json"
    )
    # If user set prefix_block_size in additional_config, use it.
    # Otherwise, set it to prefill_chunk_size.
    prefix_block_size = vllm_config.additional_config.get("prefix_block_size", None)
    if prefix_block_size is None:
        prefix_block_size = prefill_chunk_size
        logger.debug(
            "Prefix block size is set to %s based on prefill_chunk_size",
            prefix_block_size,
        )
    else:
        # User-provided value: it must align with the prefill chunking and
        # fit inside one KV-cache block.
        if prefix_block_size % prefill_chunk_size != 0:
            raise ValueError(
                "prefix_block_size ({}) is not divisible "
                "by prefill_chunk_size ({}). "
                "Please check the value of prefill_chunk_size "
                "in rbln_config.json".format(prefix_block_size, prefill_chunk_size)
            )
        if prefix_block_size > kvcache_block_size:
            raise ValueError(
                "prefix_block_size ({}) is greater than "
                "kvcache_block_size ({}). "
                "Please check the value of kvcache_block_size "
                "in rbln_config.json".format(prefix_block_size, kvcache_block_size)
            )
        logger.debug(
            "Prefix block size is set to %s based on additional_config",
            prefix_block_size,
        )
    # In both branches the final sizes must nest evenly so a KV-cache block
    # is a whole number of prefix blocks.
    if kvcache_block_size % prefix_block_size != 0:
        raise ValueError(
            "kvcache_block_size ({}) is not divisible "
            "by prefix_block_size ({}). "
            "Please check the value of prefix_block_size in rbln_config.json".format(
                kvcache_block_size, prefix_block_size
            )
        )
    vllm_config.cache_config.block_size = prefix_block_size
    vllm_config.additional_config["attn_block_size"] = kvcache_block_size
99+
100+
101+
def sync_cache_block_size(
    vllm_config: VllmConfig, kvcache_block_size: int, prefill_chunk_size: int
) -> None:
    """Align vLLM's cache block size with the compiled model's rbln_config.

    With prefix caching enabled, defers to the prefix-caching setup path;
    otherwise it simply overwrites ``cache_config.block_size`` with the
    compiled model's block size when the two disagree.
    """
    if vllm_config.cache_config.enable_prefix_caching:
        apply_prefix_caching_block_size(
            vllm_config, kvcache_block_size, prefill_chunk_size
        )
        return

    current_block_size = vllm_config.cache_config.block_size
    if current_block_size != kvcache_block_size:
        logger.info(
            "Updating model_cache_config.block_size from %s to %s "
            "based on rbln_config.json",
            current_block_size,
            kvcache_block_size,
        )
        vllm_config.cache_config.block_size = kvcache_block_size
117+
118+
119+
def sync_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> None:
    """Write the adjusted KV-cache block count into the vLLM cache config.

    *num_blocks* comes from rbln_config unless the user set
    ``num_gpu_blocks_override``, in which case the override wins and is
    also recorded in ``additional_config["num_blocks_override"]``. The
    count is then rescaled by the outer/inner block ratio before being
    stored in ``cache_config.num_gpu_blocks``.
    """
    override = vllm_config.cache_config.num_gpu_blocks_override
    if override is not None:
        num_blocks = override
        vllm_config.additional_config["num_blocks_override"] = num_blocks

    blk_ratio = get_block_ratio(vllm_config)

    # When a full batch of max-length sequences fits, keep all blocks;
    # otherwise reserve one outer block's worth (hence the -1) before
    # rescaling. The trailing +1 is an extra block in both cases.
    if is_full_block_available(num_blocks, vllm_config):
        adjusted = num_blocks * blk_ratio + 1
    else:
        adjusted = (num_blocks - 1) * blk_ratio + 1

    vllm_config.cache_config.num_gpu_blocks = adjusted
    if override is not None:
        vllm_config.cache_config.num_gpu_blocks_override = adjusted

0 commit comments

Comments
 (0)