Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions vllm_rbln/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = (
"vllm_rbln.v1.worker.rbln_worker.RBLNWorker")
scheduler_config.scheduler_cls = (
"vllm_rbln.v1.core.rbln_scheduler.RBLNScheduler")
if scheduler_config.async_scheduling:
logger.warning(
"async scheduling is incomplete because asynchronous "
"device-to-host memory copies are not yet supported")
scheduler_config.scheduler_cls = (
"vllm_rbln.v1.core.rbln_async_scheduler.RBLNAsyncScheduler"
)
else:
scheduler_config.scheduler_cls = (
"vllm_rbln.v1.core.rbln_scheduler.RBLNScheduler")

# FIXME(jiwoo.park) This is a temporary workaround.
if model_config.enforce_eager:
Expand Down
81 changes: 81 additions & 0 deletions vllm_rbln/v1/core/rbln_async_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied from vllm.v1.core.sched.async_scheduler: https://github.com/vllm-project/vllm/blob/v0.10.2/vllm/v1/core/sched/async_scheduler.py
# The only differences are:
# - Inherit RBLNScheduler instead of Scheduler
# - Use vllm-rbln logger

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus

from vllm_rbln.logger import init_logger
from vllm_rbln.v1.core.rbln_scheduler import RBLNScheduler

logger = init_logger(__name__)


class RBLNAsyncScheduler(RBLNScheduler):
    """Asynchronous variant of :class:`RBLNScheduler`.

    Copied from ``vllm.v1.core.sched.async_scheduler`` (vLLM v0.10.2); the
    only intended differences from upstream are the base class
    (``RBLNScheduler`` instead of ``Scheduler``) and the vllm-rbln logger.

    Async scheduling overlaps scheduling with model execution: token ids for
    a step may not be known yet when the next step is scheduled, so the
    scheduler tracks "placeholder" output tokens
    (``request.num_output_placeholders``) that are resolved later by the
    worker process.
    """

    def _update_after_schedule(
        self,
        scheduler_output: SchedulerOutput,
    ) -> None:
        """Post-scheduling bookkeeping for async placeholder tokens.

        After the base scheduler's update, for each scheduled request this
        reserves placeholders for the token(s) the current step will produce
        (one new token plus any scheduled speculative tokens), and records on
        ``scheduler_output`` whether any structured-output request still has
        unresolved placeholder tokens.
        """
        super()._update_after_schedule(scheduler_output)
        pending_structured_output_tokens = False
        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
            # A structured-output request with outstanding placeholders means
            # grammar state cannot be advanced yet for those tokens.
            pending_structured_output_tokens |= (
                request.use_structured_output
                and request.num_output_placeholders > 0)
            cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
            # True when this step computes through the end of all currently
            # known tokens (prompt + outputs + placeholders + spec tokens),
            # i.e. the request is at the decode frontier.
            if (request.num_computed_tokens == request.num_tokens +
                    request.num_output_placeholders + cur_num_spec_tokens):
                # The request will generate a new token plus num_spec_tokens
                # in this scheduling step.
                request.num_output_placeholders += 1 + cur_num_spec_tokens
                # Add placeholders for the new tokens in spec_token_ids.
                # We will update the actual spec token ids in the worker
                # process.
                request.spec_token_ids = [-1] * self.num_spec_tokens

        scheduler_output.pending_structured_output_tokens = (
            pending_structured_output_tokens)

    def _update_request_with_output(
        self,
        request: Request,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        """Apply newly produced token ids to *request*.

        Returns the (possibly truncated) accepted token ids and whether the
        request has stopped, mirroring the base-class contract. On top of the
        base behavior this resolves output placeholders (one per accepted
        token) and caches full KV blocks for still-running requests.
        """
        if request.discard_latest_async_tokens:
            # If the request is force preempted in reset_prefix_cache, we
            # should discard the latest async token.
            request.discard_latest_async_tokens = False
            return [], False

        # Capture status first: the super() call below may transition the
        # request out of RUNNING (e.g. to a finished state).
        status_before_update = request.status
        new_token_ids, stopped = super()._update_request_with_output(
            request, new_token_ids)

        # Update the number of output placeholders.
        request.num_output_placeholders -= len(new_token_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens. Preempted requests should be skipped.
        # Only tokens that are computed AND resolved (not placeholders) are
        # eligible for prefix-cache block caching.
        if status_before_update == RequestStatus.RUNNING:
            self.kv_cache_manager.cache_blocks(
                request,
                request.num_computed_tokens - request.num_output_placeholders)
        return new_token_ids, stopped
19 changes: 16 additions & 3 deletions vllm_rbln/v1/worker/rbln_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,16 @@ def __init__(
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
# async_output_copy_stream: torch.cuda.Stream,
):
# TODO(RBLN): We need asynchronous non blocking DtoH memcpy and a way
# to synchronize it. In original gpu vllm code, a copy operation is
# launched asynchronously, and an event is recorded on the same stream.
# Synchronizing on that event guarantees that the copy has completed
# at that point.
# We keep the original GPU code commented out for reference in future
# implementations.

self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices

Expand All @@ -129,13 +137,17 @@ def __init__(
# self._sampled_token_ids_cpu = self._sampled_token_ids.to(
# 'cpu', non_blocking=True)
# self._async_copy_ready_event.record()
# TODO(RBLN): Replace this with proper async DtoH memcpy.
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
'cpu', non_blocking=True)

def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.

This function blocks until the copy is finished.
"""
self._async_copy_ready_event.synchronize()
# TODO(RBLN): We need to synchronize DtoH memcpy here.
# self._async_copy_ready_event.synchronize()

# Release the device tensor once the copy has completed
del self._sampled_token_ids
Expand Down Expand Up @@ -325,6 +337,7 @@ def __init__(
)

self.use_async_scheduling = self.scheduler_config.async_scheduling
# TODO(RBLN): We might need stream to control DtoH memcpy.
# self.async_output_copy_stream = torch.cuda.Stream() if \
# self.use_async_scheduling else None

Expand Down Expand Up @@ -2016,7 +2029,7 @@ def sample_tokens(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
# async_output_copy_stream=self.async_output_copy_stream,
)

def load_model(self) -> None:
Expand Down