@@ -80,7 +80,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
80
80
typical_acceptance_sampler_posterior_alpha ,
81
81
disable_logprobs = speculative_config .disable_logprobs ,
82
82
disable_log_stats = speculative_config .disable_log_stats ,
83
- cpu_draft_worker = speculative_config .cpu_draft_worker )
83
+ cpu_draft_worker = speculative_config .cpu_draft_worker ,
84
+ backend_device = speculative_config .backend_device )
84
85
85
86
return spec_decode_worker
86
87
@@ -123,6 +124,7 @@ def create_worker(
123
124
disable_logprobs : bool ,
124
125
disable_log_stats : bool ,
125
126
cpu_draft_worker : Optional [bool ],
127
+ backend_device : Optional [str ],
126
128
) -> "SpecDecodeWorker" :
127
129
128
130
allow_zero_draft_token_step = True
@@ -144,24 +146,36 @@ def create_worker(
144
146
proposer_worker = MLPSpeculatorWorker (** draft_worker_kwargs )
145
147
elif cpu_draft_worker :
146
148
cpu_draft_worker_kwargs = copy .deepcopy (draft_worker_kwargs )
147
- from vllm .executor .cpu_executor import (
148
- _verify_and_get_cache_config , _verify_and_get_model_config ,
149
- _verify_and_get_scheduler_config )
149
+ base_class = None
150
+ if backend_device == "openvino" :
151
+ from vllm .executor .openvino_executor import (
152
+ _verify_and_get_cache_config , _verify_and_get_model_config )
153
+ from vllm .worker .openvino_worker import OpenVINOWorker
154
+ cpu_draft_worker_kwargs ["device_config" ].device_type = "openvino"
155
+ import openvino as ov
156
+ cpu_draft_worker_kwargs ["kv_cache_dtype" ] = ov .Type .u8
157
+ cpu_draft_worker_kwargs ["cache_config" ].cache_dtype = ov .Type .u8
158
+ base_class = OpenVINOWorker
159
+ else :
160
+ from vllm .executor .cpu_executor import (
161
+ _verify_and_get_cache_config , _verify_and_get_model_config ,
162
+ _verify_and_get_scheduler_config )
163
+ cpu_draft_worker_kwargs ["device_config" ].device_type = "cpu"
164
+ from vllm .worker .cpu_worker import CPUWorker
165
+ cpu_draft_worker_kwargs ["scheduler_config" ] = _verify_and_get_scheduler_config (
166
+ cpu_draft_worker_kwargs ["scheduler_config" ])
167
+ base_class = CPUWorker
150
168
cpu_draft_worker_kwargs [
151
169
"cache_config" ] = _verify_and_get_cache_config (
152
170
cpu_draft_worker_kwargs ["cache_config" ])
153
171
cpu_draft_worker_kwargs [
154
172
"model_config" ] = _verify_and_get_model_config (
155
173
cpu_draft_worker_kwargs ["model_config" ])
156
- cpu_draft_worker_kwargs [
157
- "scheduler_config" ] = _verify_and_get_scheduler_config (
158
- cpu_draft_worker_kwargs ["scheduler_config" ])
159
-
160
174
cpu_draft_worker_kwargs ["device_config" ].device = torch .device (
161
175
"cpu" )
162
- cpu_draft_worker_kwargs ["device_config" ].device_type = "cpu"
163
176
cpu_draft_worker_kwargs .pop ("observability_config" )
164
- proposer_worker = CPUMultiStepWorker (** cpu_draft_worker_kwargs )
177
+ cls = type ('DynamicClass' , (CPUMultiStepWorker , base_class ), {})
178
+ proposer_worker = cls (** cpu_draft_worker_kwargs )
165
179
elif draft_worker_kwargs [
166
180
"model_config" ].hf_config .model_type == "medusa" :
167
181
proposer_worker = MedusaWorker (** draft_worker_kwargs )
0 commit comments