Commit c87f065

async_schedule
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
1 parent c175bc2 commit c87f065

4 files changed: +86 additions, -44 deletions

vllm_omni/core/sched/omni_ar_scheduler.py

Lines changed: 9 additions & 20 deletions
@@ -20,9 +20,10 @@
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 from vllm_omni.core.sched.output import OmniSchedulerOutput
-from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
-from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import OmniChunkTransferManager
-from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
+from vllm_omni.distributed.omni_connectors.transfer_manager.base import OmniModelMode
+from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import (
+    OmniChunkTransferManager,
+)
 
 logger = init_logger(__name__)
 
@@ -65,17 +66,12 @@ def __init__(self, *args, **kwargs):
         # Track requests that have already triggered prefill transfer to avoid duplicates
         self.transfer_triggered_requests: set[str] = set()
         model_config = self.vllm_config.model_config
-        self.omni_connector = None
         self.chunk_manager = None
-        if model_config.async_chunk:
-            connector_config = model_config.stage_connector_config
-            connector_specs = ConnectorSpec(
-                name=connector_config.get("name", "SharedMemoryConnector"),
-                extra=connector_config.get("extra", {}),
-            )
-            self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)
-            self.chunk_manager = OmniChunkTransferManager(self.omni_connector)
+        if getattr(model_config, "async_chunk", False):
+            self.chunk_manager = OmniChunkTransferManager(
+                model_config, OmniModelMode.MODE_AR)
 
+        if self.chunk_manager:
             custom_process_next_stage_input_func = getattr(
                 self.vllm_config.model_config, "custom_process_next_stage_input_func", None
             )
@@ -192,15 +188,8 @@ def schedule(self) -> SchedulerOutput:  # type: ignore[override]
                 new_list.append(omni_nr)
 
             scheduler_output.scheduled_new_reqs = new_list  # type: ignore[assignment]
-            cached_reqs = scheduler_output.scheduled_cached_reqs
-            if not hasattr(cached_reqs, "additional_information"):
-                cached_reqs.additional_information = {}
-            for req_id in cached_reqs.req_ids:
-                request = self.requests.get(req_id) if req_id else None
-                additional_info = getattr(request, "additional_information", None) if request else None
-                cached_reqs.additional_information[req_id] = additional_info
             if self.chunk_manager:
-                self.chunk_manager.filter_scheduler_output(scheduler_output)
+                self.chunk_manager.filter_scheduler_output(scheduler_output, self.requests)
             # Add information about requests needing KV cache transfer
             finished_reqs = self.get_finished_requests_needing_kv_transfer()
         except Exception:
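
Note: the net effect of the hunks above is that the AR scheduler no longer assembles a connector itself. A minimal sketch of the new wiring, assuming only what the diff shows (model_config stands in for self.vllm_config.model_config, and build_ar_chunk_manager is a hypothetical helper, not code from this commit):

from vllm_omni.distributed.omni_connectors.transfer_manager.base import OmniModelMode
from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import (
    OmniChunkTransferManager,
)


def build_ar_chunk_manager(model_config):
    # Mirrors the new __init__ logic: connector creation now happens inside
    # OmniChunkTransferManager, so the scheduler only checks the async_chunk
    # flag and declares its role in the pipeline.
    if getattr(model_config, "async_chunk", False):
        return OmniChunkTransferManager(model_config, OmniModelMode.MODE_AR)
    return None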

vllm_omni/core/sched/omni_generation_scheduler.py

Lines changed: 10 additions & 15 deletions
@@ -16,26 +16,21 @@
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
-from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
-from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import OmniChunkTransferManager
-from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
+from vllm_omni.distributed.omni_connectors.transfer_manager.base import OmniModelMode
+from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import (
+    OmniChunkTransferManager,
+)
 from vllm_omni.outputs import OmniModelRunnerOutput
 
 
 class OmniGenerationScheduler(VLLMScheduler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         model_config = self.vllm_config.model_config
-        self.omni_connector = None
         self.chunk_manager = None
-        if model_config.async_chunk:
-            connector_config = model_config.stage_connector_config
-            connector_specs = ConnectorSpec(
-                name=connector_config.get("name", "SharedMemoryConnector"),
-                extra=connector_config.get("extra", {}),
-            )
-            self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)
-            self.chunk_manager = OmniChunkTransferManager(self.omni_connector)
+        if getattr(model_config, "async_chunk", False):
+            self.chunk_manager = OmniChunkTransferManager(
+                model_config, OmniModelMode.MODE_GENERATION)
 
         self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None)
 
@@ -76,7 +71,7 @@ def schedule(self) -> SchedulerOutput:
             # OMNI: Skip requests that are not in self.requests
             # This can happen when connector marks request as finished and it's removed from requests
             if request.request_id not in self.requests or (
-                self.omni_connector is None and request.status == RequestStatus.FINISHED_STOPPED
+                self.chunk_manager is None and request.status == RequestStatus.FINISHED_STOPPED
             ):
                 already_finished_reqs.add(request)
                 req_index += 1
@@ -115,7 +110,7 @@ def schedule(self) -> SchedulerOutput:
             request = self.waiting.peek_request()
             # OMNI: Skip requests that are not in self.requests
             if request.request_id not in self.requests or (
-                self.omni_connector is None and request.status == RequestStatus.FINISHED_STOPPED
+                self.chunk_manager is None and request.status == RequestStatus.FINISHED_STOPPED
             ):
                 # Pop the finished request from waiting queue and don't schedule it
                 self.waiting.pop_request()
@@ -367,7 +362,7 @@ def update_from_output(
 
             # Diffusion request: completes in one step; mark finished and free resources
             if request.status == RequestStatus.FINISHED_STOPPED or (
-                self.omni_connector is None and request.num_computed_tokens >= request.num_prompt_tokens
+                self.chunk_manager is None and request.num_computed_tokens >= request.num_prompt_tokens
             ):
                 request.status = RequestStatus.FINISHED_STOPPED
                 # Optional: set a stop_reason for front-end clarity

vllm_omni/distributed/omni_connectors/transfer_manager/base.py

Lines changed: 18 additions & 2 deletions
@@ -1,23 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import enum
 import threading
 import time
+from typing import Any
 
 from ..utils.logging import get_connector_logger
 
 logger = get_connector_logger(__name__)
 
 
+class OmniModelMode(enum.Enum):
+    # Omni AR Model
+    MODE_AR = "ar"
+
+    # Omni Generation Model
+    MODE_GENERATION = "generate"
+
+
 class OmniTransferManagerBase:
     """Base class for managing asynchronous data transfer via OmniConnector.
 
     This class handles the core loop logic and connector interactions, but
     leaves the specific data processing (chunks, KV cache, etc.) to subclasses.
     """
 
-    def __init__(self, connector):
-        self.connector = connector
+    def __init__(self, config: Any, mode: Any):
+        self.config = config
+        if not hasattr(self, "connector"):
+            self.connector = None
         # Requests that are waiting to be polled
         self._pending_load_reqs = {}
         # Requests that have successfully retrieved data
@@ -37,6 +49,10 @@ def __init__(self, connector):
         self.save_thread = threading.Thread(target=self.save_loop, daemon=True)
         self.save_thread.start()
 
+    @classmethod
+    def create_connector(cls, model_config: Any):
+        raise NotImplementedError
+
     def recv_loop(self):
         """Loop to poll for incoming data."""
         while not self.stop_event.is_set():
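
Note: taken together, the OmniModelMode enum, the (config, mode) constructor, and the create_connector stub define the contract a transfer manager subclass is expected to follow. A hedged sketch of a hypothetical subclass (only the imported names are real; the class itself is illustrative):

from typing import Any

from vllm_omni.distributed.omni_connectors.transfer_manager.base import (
    OmniModelMode,
    OmniTransferManagerBase,
)


class MyTransferManager(OmniTransferManagerBase):
    """Hypothetical subclass showing the contract added in this commit."""

    def __init__(self, model_config: Any, mode: OmniModelMode):
        # Set self.connector before calling the base __init__, which only
        # defaults it to None when unset; OmniChunkTransferManager likewise
        # stores the mode itself (as model_mode) before delegating.
        self.connector = self.create_connector(model_config)
        self.model_mode = mode
        super().__init__(model_config, mode)

    @classmethod
    def create_connector(cls, model_config: Any):
        # A real subclass would build and return an OmniConnector here;
        # raising keeps the sketch honest about what it does not construct.
        raise NotImplementedError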

vllm_omni/distributed/omni_connectors/transfer_manager/chunk_transfer_manager.py

Lines changed: 49 additions & 7 deletions
@@ -7,18 +7,21 @@
 import torch
 from vllm.v1.request import Request, RequestStatus
 
+from ..factory import OmniConnectorFactory
+from ..utils.config import ConnectorSpec
 from ..utils.logging import get_connector_logger
-from .base import OmniTransferManagerBase
+from .base import OmniModelMode, OmniTransferManagerBase
 
 logger = get_connector_logger(__name__)
 
 
 class OmniChunkTransferManager(OmniTransferManagerBase):
     """Manages asynchronous retrieval and storage of data chunks via OmniConnector."""
 
-    def __init__(self, connector):
-        super().__init__(connector)
-
+    def __init__(self, model_config: Any, mode: OmniModelMode):
+        self.connector = self.create_connector(model_config)
+        self.model_mode = mode
+        super().__init__(model_config)
         # State specific to Chunk management
         self.put_requests: dict[str, int] = defaultdict(int)
         self.get_requests: dict[str, int] = defaultdict(int)
@@ -32,6 +35,23 @@ def __init__(self, connector):
         self.waiting_for_chunk_running_requests: deque[Any] = deque()
         self.requests_with_ready_chunks = set()
 
+    @classmethod
+    def create_connector(cls, model_config: Any):
+        connector_config = getattr(model_config, "stage_connector_config", None)
+        if connector_config is None:
+            connector_config = {}
+        elif not isinstance(connector_config, dict):
+            connector_config = {
+                "name": getattr(connector_config, "name", None),
+                "extra": getattr(connector_config, "extra", {}),
+            }
+
+        connector_specs = ConnectorSpec(
+            name=connector_config.get("name", "SharedMemoryConnector"),
+            extra=connector_config.get("extra", {}),
+        )
+        return OmniConnectorFactory.create_connector(connector_specs)
+
     def load(self, request):
         """Request to retrieve a chunk of data for a specific request.
 
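
Note: create_connector normalizes two config shapes before delegating to the factory. A small usage sketch with stand-in configs (SimpleNamespace is illustrative; real callers pass vLLM's model config, and the final call assumes SharedMemoryConnector is registered with OmniConnectorFactory):

from types import SimpleNamespace

from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import (
    OmniChunkTransferManager,
)

# Dict-shaped stage_connector_config, as the schedulers previously unpacked by hand.
dict_cfg = SimpleNamespace(
    stage_connector_config={"name": "SharedMemoryConnector", "extra": {}}
)

# Attribute-shaped config; the elif branch above folds it into the same dict form.
obj_cfg = SimpleNamespace(
    stage_connector_config=SimpleNamespace(name="SharedMemoryConnector", extra={})
)

# Both shapes resolve to ConnectorSpec(name="SharedMemoryConnector", extra={}).
connector = OmniChunkTransferManager.create_connector(dict_cfg)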
@@ -115,7 +135,7 @@ def _process_single_recv(self, req_id: str):
         self.get_requests[req_id] += 1
         req = self._pending_load_reqs[req_id]
 
-        if stage_id != 2:
+        if self.mode == OmniModelMode.MODE_AR:
             self._update_request_payload(external_req_id, payload_data)
             req.additional_information = payload_data
             if payload_data.get("finished"):
@@ -211,12 +231,33 @@ def restore_queues(self, waiting_queue: Any, running_queue: list[Request]) -> None:
         running_queue.extend(self.waiting_for_chunk_running_requests)
         self.waiting_for_chunk_running_requests = deque()
 
-    def filter_scheduler_output(self, scheduler_output: Any) -> None:
+    def filter_scheduler_output(
+        self,
+        scheduler_output: Any,
+        requests: dict[str, Request] | None = None,
+    ) -> None:
         """
-        Clean up ready chunks from scheduler output.
+        Add addtitional info for cached requests and
+        clean up ready chunks from scheduler output.
         """
+        if requests is not None:
+            self.attach_cached_additional_information(scheduler_output, requests)
         self._clear_chunk_ready(scheduler_output)
 
+    @staticmethod
+    def attach_cached_additional_information(
+        scheduler_output: Any, requests: dict[str, Request]
+    ) -> None:
+        cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None)
+        if not cached_reqs:
+            return
+        if not hasattr(cached_reqs, "additional_information"):
+            cached_reqs.additional_information = {}
+        for req_id in cached_reqs.req_ids:
+            request = requests.get(req_id) if req_id else None
+            additional_info = getattr(request, "additional_information", None) if request else None
+            cached_reqs.additional_information[req_id] = additional_info
+
     def _process_chunk_queue(
         self,
         queue: Any,
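
Note: attach_cached_additional_information is the cached-request bookkeeping that the AR scheduler previously inlined (see the deletion in omni_ar_scheduler.py above), relocated into a reusable static helper. A minimal sketch of what it does, with SimpleNamespace stand-ins for vLLM's real objects (the field names follow the diff; the values are illustrative):

from types import SimpleNamespace

from vllm_omni.distributed.omni_connectors.transfer_manager.chunk_transfer_manager import (
    OmniChunkTransferManager,
)

# A request carrying side-channel data, and a cached batch that references it.
req = SimpleNamespace(additional_information={"chunk_idx": 0, "finished": False})
cached = SimpleNamespace(req_ids=["req-1"])
scheduler_output = SimpleNamespace(scheduled_cached_reqs=cached)

OmniChunkTransferManager.attach_cached_additional_information(
    scheduler_output, {"req-1": req}
)

# Each request's additional_information is mirrored onto the cached batch,
# keyed by request id; unknown request ids map to None.
assert cached.additional_information == {"req-1": {"chunk_idx": 0, "finished": False}}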
@@ -232,6 +273,7 @@ def _process_chunk_queue(
             # Access finished_requests from self instead of connector
             if request.request_id in self.finished_requests:
                 request.additional_information = {}
+                self.finished_requests.remove(request.request_id)
                 continue
             self.load(request)
             request.status = RequestStatus.WAITING_FOR_CHUNK
