
Commit eae59cd

Enable OV backend in sync mode (#13)
Co-authored-by: xiaodong <[email protected]>
1 parent b7ccf41 commit eae59cd

9 files changed: +63 −20 lines

setup.py

+1 −1

@@ -282,7 +282,7 @@ def _is_xpu() -> bool:
 
 
 def _build_custom_ops() -> bool:
-    return _is_cuda() or _is_hip() or _is_cpu()
+    return _is_cuda() or _is_hip() or _is_cpu() or _is_openvino()
 
 
 def _build_core_ext() -> bool:
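
Note: `_is_openvino()` is referenced here but defined elsewhere in setup.py. Judging by the sibling helpers (`_is_cuda()`, `_is_cpu()`), it presumably keys off the VLLM_TARGET_DEVICE environment variable; a minimal sketch under that assumption:

# Hedged sketch of _is_openvino(), mirroring the existing _is_cuda()/_is_cpu()
# helpers. VLLM_TARGET_DEVICE as the selector is an assumption, not shown in
# this diff.
import os

VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda")

def _is_openvino() -> bool:
    # True when the build targets the OpenVINO backend.
    return VLLM_TARGET_DEVICE == "openvino"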

vllm/attention/selector.py

+1 −1

@@ -187,7 +187,7 @@ def which_attn_to_use(
 
         return _Backend.TORCH_SDPA
 
-    if is_openvino():
+    if is_openvino() or device == "openvino":
        if selected_backend != _Backend.OPENVINO:
            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
        return _Backend.OPENVINO
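
This change lets a caller force the OpenVINO attention backend by passing an explicit device string at runtime, even on a build where the platform check `is_openvino()` is false; that is how the CPU draft worker below requests it. A self-contained toy of the selection rule (stand-in names, not vLLM APIs):

# Toy illustration: an explicit device string can force the OpenVINO backend
# even when the build-time platform check is false.
from enum import Enum, auto

class _Backend(Enum):
    TORCH_SDPA = auto()
    OPENVINO = auto()

def is_openvino() -> bool:
    return False  # stand-in for the build-time platform check

def which_backend(selected: _Backend, device: str) -> _Backend:
    if is_openvino() or device == "openvino":
        return _Backend.OPENVINO
    return selected

assert which_backend(_Backend.TORCH_SDPA, "openvino") is _Backend.OPENVINO
assert which_backend(_Backend.TORCH_SDPA, "cpu") is _Backend.TORCH_SDPA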

vllm/config.py

+6 −2

@@ -1122,6 +1122,7 @@ def maybe_create_spec_config(
         typical_acceptance_sampler_posterior_alpha: Optional[float],
         disable_logprobs: Optional[bool],
         cpu_draft_worker: Optional[bool],
+        backend_device: Optional[str],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.
 
@@ -1180,7 +1181,7 @@ def maybe_create_spec_config(
                 according to the log probability settings in SamplingParams.
                 If not specified, it defaults to True.
             cpu_draft_worker (Optional[bool]): Run draft model on CPU.
-
+            backend_device (Optional[str]): Backend used to run the draft model on CPU, such as OpenVINO or PyTorch.
         Returns:
             Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                 the necessary conditions are met, else None.
@@ -1312,6 +1313,7 @@ def maybe_create_spec_config(
             disable_logprobs=disable_logprobs,
             disable_log_stats=disable_log_stats,
             cpu_draft_worker=cpu_draft_worker,
+            backend_device=backend_device,
         )
 
     @staticmethod
@@ -1408,6 +1410,7 @@ def __init__(
         disable_logprobs: bool,
         disable_log_stats: bool,
         cpu_draft_worker: Optional[bool],
+        backend_device: Optional[str],
     ):
         """Create a SpeculativeConfig object.
 
@@ -1443,6 +1446,7 @@ def __init__(
            disable_log_stats: Whether to disable periodic printing of stage
                times in speculative decoding.
            cpu_draft_worker: Run draft model on CPU.
+            backend_device: Backend used to run the draft model on CPU, such as OpenVINO or the default PyTorch CPU backend.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
@@ -1459,7 +1463,7 @@ def __init__(
         self.disable_logprobs = disable_logprobs
         self.disable_log_stats = disable_log_stats
         self.cpu_draft_worker = cpu_draft_worker or False
-
+        self.backend_device = backend_device or 'auto'
         self._verify_args()
 
     def _verify_args(self) -> None:

vllm/engine/arg_utils.py

+2

@@ -162,6 +162,7 @@ class EngineArgs:
     disable_async_output_proc: bool = False
     cpu_draft_worker: Optional[bool] = False
     override_neuron_config: Optional[Dict[str, Any]] = None
+    backend_device: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -953,6 +954,7 @@ def create_engine_config(self) -> EngineConfig:
             typical_acceptance_sampler_posterior_alpha,
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
             cpu_draft_worker=self.cpu_draft_worker,
+            backend_device=self.backend_device,
         )
 
         if self.num_scheduler_steps > 1:
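
(The new field is annotated `Optional[str]` here to match its type everywhere else in this commit; the original hunk declared it `Optional[bool]`, which appears to be a typo since it carries a device-name string.) End to end, the flag travels EngineArgs → create_engine_config() → SpeculativeConfig, where `None` is normalized to 'auto'. A toy sketch of that plumbing pattern (toy class names, not vLLM APIs):

# Toy illustration of the plumbing: a dataclass field defaulting to None that
# a downstream config normalizes to 'auto'.
from dataclasses import dataclass
from typing import Optional

@dataclass
class EngineArgsToy:
    cpu_draft_worker: Optional[bool] = False
    backend_device: Optional[str] = None   # e.g. "openvino"

class SpeculativeConfigToy:
    def __init__(self, cpu_draft_worker: Optional[bool],
                 backend_device: Optional[str]):
        self.cpu_draft_worker = cpu_draft_worker or False
        self.backend_device = backend_device or 'auto'

args = EngineArgsToy(cpu_draft_worker=True, backend_device="openvino")
cfg = SpeculativeConfigToy(args.cpu_draft_worker, args.backend_device)
assert cfg.backend_device == "openvino"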

vllm/spec_decode/multi_step_worker.py

+2 −2

@@ -14,7 +14,7 @@
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.cpu_worker import CPUWorker
 from vllm.worker.worker import Worker
-
+from vllm.worker.openvino_worker import OpenVINOWorker
 
 class MultiStepWorker(Worker, ProposerWorkerBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
@@ -368,7 +368,7 @@ def _raise_if_unsupported(
 
 
 # Copied from MultiStepWorker
-class CPUMultiStepWorker(CPUWorker):
+class CPUMultiStepWorker:
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead
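
Dropping the CPUWorker base here turns CPUMultiStepWorker into a pure mixin: the concrete base (CPUWorker or OpenVINOWorker) is now attached at runtime in spec_decode_worker.py, as the composition sketch after that file's hunks below illustrates.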

vllm/spec_decode/spec_decode_worker.py

+24 −10

@@ -80,7 +80,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_alpha,
         disable_logprobs=speculative_config.disable_logprobs,
         disable_log_stats=speculative_config.disable_log_stats,
-        cpu_draft_worker=speculative_config.cpu_draft_worker)
+        cpu_draft_worker=speculative_config.cpu_draft_worker,
+        backend_device=speculative_config.backend_device)
 
     return spec_decode_worker
 
@@ -123,6 +124,7 @@ def create_worker(
         disable_logprobs: bool,
         disable_log_stats: bool,
         cpu_draft_worker: Optional[bool],
+        backend_device: Optional[str],
     ) -> "SpecDecodeWorker":
 
         allow_zero_draft_token_step = True
@@ -144,24 +146,36 @@
             proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
         elif cpu_draft_worker:
             cpu_draft_worker_kwargs = copy.deepcopy(draft_worker_kwargs)
-            from vllm.executor.cpu_executor import (
-                _verify_and_get_cache_config, _verify_and_get_model_config,
-                _verify_and_get_scheduler_config)
+            base_class = None
+            if backend_device == "openvino":
+                from vllm.executor.openvino_executor import (
+                    _verify_and_get_cache_config, _verify_and_get_model_config)
+                from vllm.worker.openvino_worker import OpenVINOWorker
+                cpu_draft_worker_kwargs["device_config"].device_type = "openvino"
+                import openvino as ov
+                cpu_draft_worker_kwargs["kv_cache_dtype"] = ov.Type.u8
+                cpu_draft_worker_kwargs["cache_config"].cache_dtype = ov.Type.u8
+                base_class = OpenVINOWorker
+            else:
+                from vllm.executor.cpu_executor import (
+                    _verify_and_get_cache_config, _verify_and_get_model_config,
+                    _verify_and_get_scheduler_config)
+                cpu_draft_worker_kwargs["device_config"].device_type = "cpu"
+                from vllm.worker.cpu_worker import CPUWorker
+                cpu_draft_worker_kwargs["scheduler_config"] = _verify_and_get_scheduler_config(
+                    cpu_draft_worker_kwargs["scheduler_config"])
+                base_class = CPUWorker
             cpu_draft_worker_kwargs[
                 "cache_config"] = _verify_and_get_cache_config(
                     cpu_draft_worker_kwargs["cache_config"])
             cpu_draft_worker_kwargs[
                 "model_config"] = _verify_and_get_model_config(
                     cpu_draft_worker_kwargs["model_config"])
-            cpu_draft_worker_kwargs[
-                "scheduler_config"] = _verify_and_get_scheduler_config(
-                    cpu_draft_worker_kwargs["scheduler_config"])
-
             cpu_draft_worker_kwargs["device_config"].device = torch.device(
                 "cpu")
-            cpu_draft_worker_kwargs["device_config"].device_type = "cpu"
             cpu_draft_worker_kwargs.pop("observability_config")
-            proposer_worker = CPUMultiStepWorker(**cpu_draft_worker_kwargs)
+            cls = type('DynamicClass', (CPUMultiStepWorker, base_class), {})
+            proposer_worker = cls(**cpu_draft_worker_kwargs)
         elif draft_worker_kwargs[
                 "model_config"].hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
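
The `type('DynamicClass', (CPUMultiStepWorker, base_class), {})` call builds the worker class at runtime: the multi-step logic is written once in the mixin, and Python's MRO routes everything else to whichever backend base was selected. A minimal self-contained illustration of this pattern (toy class and method names, not vLLM APIs):

# Minimal illustration of runtime mixin composition via type().
class MultiStepMixin:
    def sampler_output(self) -> str:
        # Shared multi-step logic; defers single-step work to the base worker.
        return f"multi-step over {self.execute_step()}"

class CpuBase:
    def execute_step(self) -> str:
        return "cpu step"

class OpenVINOBase:
    def execute_step(self) -> str:
        return "openvino step"

def make_worker(backend_device: str):
    base = OpenVINOBase if backend_device == "openvino" else CpuBase
    cls = type('DynamicClass', (MultiStepMixin, base), {})
    return cls()

assert make_worker("openvino").sampler_output() == "multi-step over openvino step"
assert make_worker("cpu").sampler_output() == "multi-step over cpu step"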

vllm/worker/model_runner.py

+5 −1

@@ -965,7 +965,7 @@ def __init__(
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
-            self.device.type,
+            "cuda",
         ) if num_attn_heads else None
         if self.attn_backend:
             self.attn_state = self.attn_backend.get_state_cls()(
@@ -1544,6 +1544,10 @@ def execute_model(
             model_forward_end = torch.cuda.Event(enable_timing=True)
             model_forward_start.record()
 
+        model_executable = model_executable.to(self.device)
+        input_tokens = model_input.input_tokens.to(self.device)
+        input_positions = model_input.input_positions.to(self.device)
+
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
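
One semantic detail this hunk leans on: `Module.to()` moves parameters in place and returns the module itself, while `Tensor.to()` is functional and leaves the original untouched. So only the `model_executable` move is guaranteed to take effect; the two tensor moves bind new locals that the unchanged call below does not read (it still passes `model_input.input_tokens` / `model_input.input_positions`), which is harmless when the inputs already live on `self.device`. A runnable sketch of the semantics:

# Sketch of the torch device-move semantics relied on here.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
moved = model.to("cpu")        # in-place for modules; returns self
assert moved is model

x = torch.zeros(2, 4)
y = x.to("cpu")                # functional for tensors; no-op when already there
assert y is x                  # same object, no copy was made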

vllm/worker/openvino_model_runner.py

+6 −1

@@ -79,6 +79,7 @@ def __init__(
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
+            "openvino",
         )
 
         # Multi-modal data support
@@ -307,7 +308,11 @@ def prepare_input_tensors(
             sampling_metadata,
             multi_modal_kwargs,
         )
-
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
     @torch.inference_mode()
     def execute_model(
         self,

vllm/worker/openvino_worker.py

+16 −2

@@ -8,7 +8,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, PromptAdapterConfig, SpeculativeConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -37,6 +37,7 @@ def __init__(
         parallel_config: ParallelConfig,
         device_config: DeviceConfig,
     ) -> None:
+        print("######device_config.device_type: ", device_config.device_type)
         assert device_config.device_type == "openvino"
         self.cache_config = cache_config
         self.model_config = model_config
@@ -69,6 +70,7 @@ def __init__(
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
+            "openvino",
         )
 
         # Initialize the cache.
@@ -152,6 +154,8 @@ def __init__(
         multimodal_config: Optional[MultiModalConfig] = None,
         kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
         is_driver_worker: bool = False,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        speculative_config: Optional[SpeculativeConfig] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -166,6 +170,7 @@ def __init__(
         self.lora_config = lora_config
         self.multimodal_config = multimodal_config
         self.is_driver_worker = is_driver_worker
+        self.speculative_config = speculative_config
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
@@ -192,13 +197,22 @@ def __init__(
         self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]
 
     def init_device(self) -> None:
-        self.init_distributed_environment()
+        self.device = torch.device("cpu")
+        # self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)
 
     def load_model(self):
         self.model_runner.load_model()
 
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of blocks available for the KV cache.
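The two new properties give the OpenVINO worker the same read-only surface that the speculative-decoding code (and the matching `vocab_size` property added to openvino_model_runner.py above) expects of any worker, by delegating to values owned by the runner and the model config. A minimal sketch of this delegation pattern (toy classes, not vLLM APIs):

# Minimal sketch of the delegation pattern: the worker re-exports values owned
# by its runner/config as read-only properties.
class RunnerToy:
    vocab_size = 32000

class ConfigToy:
    max_model_len = 4096

class WorkerToy:
    def __init__(self):
        self.model_runner = RunnerToy()
        self.model_config = ConfigToy()

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

w = WorkerToy()
assert (w.vocab_size, w.max_model_len) == (32000, 4096)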