diff --git a/vllm_rbln/platform.py b/vllm_rbln/platform.py index b2415c386..8ea0047ec 100644 --- a/vllm_rbln/platform.py +++ b/vllm_rbln/platform.py @@ -135,8 +135,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if parallel_config.worker_cls == "auto": parallel_config.worker_cls = ( "vllm_rbln.v1.worker.rbln_worker.RBLNWorker") - scheduler_config.scheduler_cls = ( - "vllm_rbln.v1.core.rbln_scheduler.RBLNScheduler") + if scheduler_config.async_scheduling: + logger.warning( + "async scheduling is incomplete because asynchronous " + "device-to-host memory copies are not yet supported") + scheduler_config.scheduler_cls = ( + "vllm_rbln.v1.core.rbln_async_scheduler.RBLNAsyncScheduler" + ) + else: + scheduler_config.scheduler_cls = ( + "vllm_rbln.v1.core.rbln_scheduler.RBLNScheduler") # FIXME(jiwoo.park) This is a temporary workaround. if model_config.enforce_eager: diff --git a/vllm_rbln/v1/core/rbln_async_scheduler.py b/vllm_rbln/v1/core/rbln_async_scheduler.py new file mode 100644 index 000000000..62e16e7f3 --- /dev/null +++ b/vllm_rbln/v1/core/rbln_async_scheduler.py @@ -0,0 +1,81 @@ +# Copyright 2025 Rebellions Inc. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied from vllm.v1.core.sched.async_scheduler: https://github.com/vllm-project/vllm/blob/v0.10.2/vllm/v1/core/sched/async_scheduler.py +# The only differences are: +# - Inherit RBLNScheduler instead of Scheduler +# - Use vllm-rbln logger + +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request, RequestStatus + +from vllm_rbln.logger import init_logger +from vllm_rbln.v1.core.rbln_scheduler import RBLNScheduler + +logger = init_logger(__name__) + + +class RBLNAsyncScheduler(RBLNScheduler): + + def _update_after_schedule( + self, + scheduler_output: SchedulerOutput, + ) -> None: + super()._update_after_schedule(scheduler_output) + pending_structured_output_tokens = False + spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + pending_structured_output_tokens |= ( + request.use_structured_output + and request.num_output_placeholders > 0) + cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ())) + if (request.num_computed_tokens == request.num_tokens + + request.num_output_placeholders + cur_num_spec_tokens): + # The request will generate a new token plus num_spec_tokens + # in this scheduling step. + request.num_output_placeholders += 1 + cur_num_spec_tokens + # Add placeholders for the new tokens in spec_token_ids. + # We will update the actual spec token ids in the worker + # process. + request.spec_token_ids = [-1] * self.num_spec_tokens + + scheduler_output.pending_structured_output_tokens = ( + pending_structured_output_tokens) + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + if request.discard_latest_async_tokens: + # If the request is force preempted in reset_prefix_cache, we + # should discard the latest async token. + request.discard_latest_async_tokens = False + return [], False + + status_before_update = request.status + new_token_ids, stopped = super()._update_request_with_output( + request, new_token_ids) + + # Update the number of output placeholders. + request.num_output_placeholders -= len(new_token_ids) + assert request.num_output_placeholders >= 0 + + # Cache the new tokens. Preempted requests should be skipped. + if status_before_update == RequestStatus.RUNNING: + self.kv_cache_manager.cache_blocks( + request, + request.num_computed_tokens - request.num_output_placeholders) + return new_token_ids, stopped diff --git a/vllm_rbln/v1/worker/rbln_model_runner.py b/vllm_rbln/v1/worker/rbln_model_runner.py index 84474e4de..0c1e7ab0b 100644 --- a/vllm_rbln/v1/worker/rbln_model_runner.py +++ b/vllm_rbln/v1/worker/rbln_model_runner.py @@ -110,8 +110,16 @@ def __init__( model_runner_output: ModelRunnerOutput, sampled_token_ids: torch.Tensor, invalid_req_indices: list[int], - async_output_copy_stream: torch.cuda.Stream, + # async_output_copy_stream: torch.cuda.Stream, ): + # TODO(RBLN): We need asynchronous non blocking DtoH memcpy and a way + # to synchronize it. In original gpu vllm code, a copy operation is + # launched asynchronously, and an event is recorded on the same stream. + # Synchronizing on that event guarantees that the copy has completed + # at that point. + # We keep the original GPU code commented out for reference in future + # implementations. + self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices @@ -129,13 +137,17 @@ def __init__( # self._sampled_token_ids_cpu = self._sampled_token_ids.to( # 'cpu', non_blocking=True) # self._async_copy_ready_event.record() + # TODO(RBLN): Replace this with proper async DtoH memcpy. + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. This function blocks until the copy is finished. """ - self._async_copy_ready_event.synchronize() + # TODO(RBLN): We need to synchronize DtoH memcpy here. + # self._async_copy_ready_event.synchronize() # Release the device tensor once the copy has completed del self._sampled_token_ids @@ -325,6 +337,7 @@ def __init__( ) self.use_async_scheduling = self.scheduler_config.async_scheduling + # TODO(RBLN): We might need stream to control DtoH memcpy. # self.async_output_copy_stream = torch.cuda.Stream() if \ # self.use_async_scheduling else None @@ -2016,7 +2029,7 @@ def sample_tokens( model_runner_output=output, sampled_token_ids=sampler_output.sampled_token_ids, invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, + # async_output_copy_stream=self.async_output_copy_stream, ) def load_model(self) -> None: