Enable OV backend in sync mode #13

Merged · 1 commit · Oct 25, 2024
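
This PR wires an OpenVINO backend option into speculative decoding so the draft model can run on the host CPU through OpenVINO while the target model keeps its usual backend. A minimal sketch (not taken from the PR description) of how the new knobs are meant to be used; the field names `cpu_draft_worker` and `backend_device` come from this diff, while the model names and the remaining speculative-decoding arguments are illustrative assumptions:

```python
# Hedged sketch: cpu_draft_worker / backend_device are fields added in this PR;
# the model choices and other settings are assumptions for illustration only.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-chat-hf",    # target model (assumed)
    speculative_model="JackFram/llama-68m",   # draft model (assumed)
    num_speculative_tokens=5,
    cpu_draft_worker=True,                    # run the draft model on the host CPU
    backend_device="openvino",                # new in this PR: OpenVINO draft backend
)
engine_config = engine_args.create_engine_config()
```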
2 changes: 1 addition & 1 deletion setup.py
@@ -282,7 +282,7 @@ def _is_xpu() -> bool:


def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()
return _is_cuda() or _is_hip() or _is_cpu() or _is_openvino()


def _build_core_ext() -> bool:
2 changes: 1 addition & 1 deletion vllm/attention/selector.py
@@ -187,7 +187,7 @@ def which_attn_to_use(

return _Backend.TORCH_SDPA

if is_openvino():
if is_openvino() or device == "openvino":
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
8 changes: 6 additions & 2 deletions vllm/config.py
@@ -1122,6 +1122,7 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
cpu_draft_worker: Optional[bool],
backend_device: Optional[str],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.

@@ -1180,7 +1181,7 @@ def maybe_create_spec_config(
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
cpu_draft_worker (Optional[bool]): Run draft model on CPU.

backend_device (Optional[str]): Backend used for the draft model when it runs on CPU, e.g. OpenVINO or the default PyTorch CPU path.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
@@ -1312,6 +1313,7 @@ def maybe_create_spec_config(
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
cpu_draft_worker=cpu_draft_worker,
backend_device=backend_device,
)

@staticmethod
@@ -1408,6 +1410,7 @@ def __init__(
disable_logprobs: bool,
disable_log_stats: bool,
cpu_draft_worker: Optional[bool],
backend_device: Optional[str],
):
"""Create a SpeculativeConfig object.

@@ -1443,6 +1446,7 @@ def __init__(
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
cpu_draft_worker: Run draft model on CPU.
backend_device (Optional[str]): Backend used for the draft model when it runs on CPU, such as OpenVINO or the default PyTorch CPU path.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
@@ -1459,7 +1463,7 @@ def __init__(
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self.cpu_draft_worker = cpu_draft_worker or False

self.backend_device = backend_device or 'auto'
self._verify_args()

def _verify_args(self) -> None:
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
@@ -162,6 +162,7 @@ class EngineArgs:
disable_async_output_proc: bool = False
cpu_draft_worker: Optional[bool] = False
override_neuron_config: Optional[Dict[str, Any]] = None
backend_device: Optional[str] = None

def __post_init__(self):
if self.tokenizer is None:
@@ -953,6 +954,7 @@ def create_engine_config(self) -> EngineConfig:
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
cpu_draft_worker=self.cpu_draft_worker,
backend_device=self.backend_device,
)

if self.num_scheduler_steps > 1:
4 changes: 2 additions & 2 deletions vllm/spec_decode/multi_step_worker.py
@@ -14,7 +14,7 @@
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.cpu_worker import CPUWorker
from vllm.worker.worker import Worker

from vllm.worker.openvino_worker import OpenVINOWorker

class MultiStepWorker(Worker, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
@@ -368,7 +368,7 @@ def _raise_if_unsupported(


# Copied from MultiStepWorker
class CPUMultiStepWorker(CPUWorker):
class CPUMultiStepWorker:
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
34 changes: 24 additions & 10 deletions vllm/spec_decode/spec_decode_worker.py
@@ -80,7 +80,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
cpu_draft_worker=speculative_config.cpu_draft_worker)
cpu_draft_worker=speculative_config.cpu_draft_worker,
backend_device=speculative_config.backend_device)

return spec_decode_worker

@@ -123,6 +124,7 @@ def create_worker(
disable_logprobs: bool,
disable_log_stats: bool,
cpu_draft_worker: Optional[bool],
backend_device: Optional[str],
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
@@ -144,24 +146,36 @@
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif cpu_draft_worker:
cpu_draft_worker_kwargs = copy.deepcopy(draft_worker_kwargs)
from vllm.executor.cpu_executor import (
_verify_and_get_cache_config, _verify_and_get_model_config,
_verify_and_get_scheduler_config)
base_class = None
if backend_device == "openvino":
from vllm.executor.openvino_executor import (
_verify_and_get_cache_config, _verify_and_get_model_config)
from vllm.worker.openvino_worker import OpenVINOWorker
cpu_draft_worker_kwargs["device_config"].device_type = "openvino"
import openvino as ov
cpu_draft_worker_kwargs["kv_cache_dtype"] = ov.Type.u8
cpu_draft_worker_kwargs["cache_config"].cache_dtype = ov.Type.u8
base_class = OpenVINOWorker
else:
from vllm.executor.cpu_executor import (
_verify_and_get_cache_config, _verify_and_get_model_config,
_verify_and_get_scheduler_config)
cpu_draft_worker_kwargs["device_config"].device_type = "cpu"
from vllm.worker.cpu_worker import CPUWorker
cpu_draft_worker_kwargs["scheduler_config"] = _verify_and_get_scheduler_config(
cpu_draft_worker_kwargs["scheduler_config"])
base_class = CPUWorker
cpu_draft_worker_kwargs[
"cache_config"] = _verify_and_get_cache_config(
cpu_draft_worker_kwargs["cache_config"])
cpu_draft_worker_kwargs[
"model_config"] = _verify_and_get_model_config(
cpu_draft_worker_kwargs["model_config"])
cpu_draft_worker_kwargs[
"scheduler_config"] = _verify_and_get_scheduler_config(
cpu_draft_worker_kwargs["scheduler_config"])

cpu_draft_worker_kwargs["device_config"].device = torch.device(
"cpu")
cpu_draft_worker_kwargs["device_config"].device_type = "cpu"
cpu_draft_worker_kwargs.pop("observability_config")
proposer_worker = CPUMultiStepWorker(**cpu_draft_worker_kwargs)
cls = type('DynamicClass', (CPUMultiStepWorker, base_class), {})
proposer_worker = cls(**cpu_draft_worker_kwargs)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
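
The draft worker above is assembled at runtime: `CPUMultiStepWorker` no longer inherits from `CPUWorker` (see `multi_step_worker.py`), and `type('DynamicClass', (CPUMultiStepWorker, base_class), {})` composes it with either `CPUWorker` or `OpenVINOWorker` depending on `backend_device`. A self-contained toy sketch of this pattern; the classes here are stand-ins, not the real vLLM workers:

```python
# Toy illustration of the type()-based mixin used in create_worker above.
class MultiStepMixin:
    """Adds multi-step behaviour on top of whatever backend worker it is mixed with."""
    def run_steps(self, n: int):
        return [self.execute_model() for _ in range(n)]

class ToyCPUWorker:
    def execute_model(self):
        return "cpu step"

class ToyOpenVINOWorker:
    def execute_model(self):
        return "openvino step"

def make_worker(backend_device: str):
    base = ToyOpenVINOWorker if backend_device == "openvino" else ToyCPUWorker
    # type(name, bases, namespace) builds a new class whose MRO places the
    # multi-step mixin in front of the chosen backend worker.
    cls = type("DynamicWorker", (MultiStepMixin, base), {})
    return cls()

print(make_worker("openvino").run_steps(2))  # ['openvino step', 'openvino step']
```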
6 changes: 5 additions & 1 deletion vllm/worker/model_runner.py
@@ -965,7 +965,7 @@ def __init__(
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.device.type,
"cuda",
) if num_attn_heads else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
@@ -1544,6 +1544,10 @@ def execute_model(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()

model_executable = model_executable.to(self.device)
input_tokens = model_input.input_tokens.to(self.device)
input_positions = model_input.input_positions.to(self.device)

hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
7 changes: 6 additions & 1 deletion vllm/worker/openvino_model_runner.py
@@ -79,6 +79,7 @@ def __init__(
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
"openvino",
)

# Multi-modal data support
@@ -307,7 +308,11 @@ def prepare_input_tensors(
sampling_metadata,
multi_modal_kwargs,
)


@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

@torch.inference_mode()
def execute_model(
self,
18 changes: 16 additions & 2 deletions vllm/worker/openvino_worker.py
@@ -8,7 +8,7 @@
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
SchedulerConfig, PromptAdapterConfig, SpeculativeConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@@ -37,6 +37,7 @@ def __init__(
parallel_config: ParallelConfig,
device_config: DeviceConfig,
) -> None:
print("######device_config.device_type: ", device_config.device_type)
assert device_config.device_type == "openvino"
self.cache_config = cache_config
self.model_config = model_config
@@ -69,6 +70,7 @@ def __init__(
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
"openvino",
)

# Initialize the cache.
@@ -152,6 +154,8 @@ def __init__(
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
@@ -166,6 +170,7 @@ def __init__(
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
self.speculative_config = speculative_config
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

@@ -192,13 +197,22 @@ def __init__(
self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]

def init_device(self) -> None:
self.init_distributed_environment()
self.device = torch.device("cpu")
# self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)

def load_model(self):
self.model_runner.load_model()

@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size

@property
def max_model_len(self) -> int:
return self.model_config.max_model_len

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.

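
The new `vocab_size` and `max_model_len` properties mirror what the other worker classes expose, because the speculative-decoding plumbing queries them on the proposer worker. A hypothetical illustration of that kind of cross-check, not code from this PR:

```python
# Hypothetical helper, only to show why the OpenVINO worker must expose
# vocab_size and max_model_len like the other vLLM worker classes do.
def validate_draft_worker(scorer_worker, proposer_worker) -> None:
    if scorer_worker.vocab_size != proposer_worker.vocab_size:
        raise ValueError("Draft and target models must share a vocabulary "
                         "for speculative decoding.")
    if proposer_worker.max_model_len < scorer_worker.max_model_len:
        # A shorter draft context limits how far speculation can run.
        print("Warning: draft model supports a shorter context than the target.")
```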