[WIP] feat: support flexible inference path for omni model #241
Hangzhi wants to merge 1 commit into sgl-project:main from
Conversation
Allow per-request control over which output modalities are generated in the speech pipeline. When a request specifies modalities=["text"], the audio stages (talker_ar, code_predictor, code2wav) are skipped entirely instead of running and being filtered at the API layer.
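The routing decision described above can be sketched roughly as follows. This is an illustrative stand-in, not the PR's actual code: the stage names (`decode`, `talker_ar`) match the pipeline described in this PR, but the `Request` dataclass and list-based return shape are simplified assumptions.

```python
from dataclasses import dataclass, field


@dataclass
class Request:
    """Simplified stand-in for the pipeline's request object."""
    metadata: dict = field(default_factory=dict)


AUDIO_STAGES = ["talker_ar"]  # talker_ar later feeds code_predictor -> code2wav
TEXT_STAGES = ["decode"]


def thinker_next_flexible(request: Request) -> list:
    """Return the stages the thinker fans out to for this particular request."""
    modalities = (request.metadata or {}).get("output_modalities")
    if modalities is not None and "audio" not in modalities:
        # Text-only request: skip the audio branch entirely.
        return list(TEXT_STAGES)
    # Audio requested, or modalities unspecified: keep the existing full fan-out.
    return TEXT_STAGES + AUDIO_STAGES
```

With this shape, `modalities=["text"]` yields only the text decode stage, while `["text", "audio"]` or an absent key preserves the old behavior.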
Pull request overview
Adds per-request modality-aware routing to the Qwen3-Omni speech pipeline so text-only requests skip the audio stages (and can complete once decode finishes), while mixed text+audio requests keep the existing behavior.
Changes:
- Introduces `thinker_next_flexible` routing based on `request.metadata["output_modalities"]`.
- Adds per-request terminal stage overrides in the coordinator so text-only requests can resolve after decode.
- Suppresses streaming/bootstrap/stream-done signals toward audio targets for text-only requests to prevent per-request queue/pending-data leaks.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `sglang_omni/models/qwen3_omni/pipeline/next_stage.py` | Adds modality-aware thinker fan-out (`thinker_next_flexible`). |
| `sglang_omni/models/qwen3_omni/config.py` | Switches speech pipeline thinker routing to the flexible function. |
| `sglang_omni/pipeline/coordinator.py` | Tracks per-request expected terminal stages; updates stream/completion resolution logic accordingly. |
| `sglang_omni/pipeline/worker/runtime.py` | Suppresses streaming/bootstrap and stream-done toward audio targets for text-only requests. |
| `tests/test_flexible_inference.py` | Adds unit tests for flexible routing and per-request terminal resolution. |
```python
finally:
    self._stream_queues.pop(request_id, None)
    self._completion_futures.pop(request_id, None)
    self._request_terminals.pop(request_id, None)
```
In stream(), the finally block unconditionally pops self._request_terminals[request_id]. If a client cancels/closes the stream generator early (e.g. disconnect), this removes the per-request terminal override while the request may still be running. For text-only requests in the speech pipeline, subsequent _handle_completion() calls will then fall back to the global terminal set (e.g. {"decode","code2wav"}) and can wait forever for code2wav that will never run, leaking _requests/_partial_results. Fix: don’t clear _request_terminals here (let _handle_completion()/abort() own cleanup), or abort the request before clearing to guarantee completion paths still resolve.
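One way to realize the suggested fix is to split cleanup ownership: the stream generator's `finally` clears only stream-local state, while completion/abort paths own the terminal override. The class below is a minimal stand-in, not the repo's `Coordinator`; the helper names `_cleanup_stream_state` and `_finalize_request` are hypothetical.

```python
class Coordinator:
    """Minimal sketch of split cleanup ownership for per-request state."""

    def __init__(self):
        self._stream_queues = {}
        self._completion_futures = {}
        self._request_terminals = {}

    def _cleanup_stream_state(self, request_id):
        # Called from stream()'s finally block. Only stream-local state is
        # cleared here: the terminal override must survive an early client
        # disconnect so _handle_completion() still resolves against the
        # per-request terminal set instead of the global one.
        self._stream_queues.pop(request_id, None)
        self._completion_futures.pop(request_id, None)

    def _finalize_request(self, request_id):
        # Called from _handle_completion()/abort(), which own the terminal
        # override's full lifetime.
        self._request_terminals.pop(request_id, None)
```

After an early disconnect, the override stays in place until the request actually completes or is aborted, so text-only requests cannot regress to waiting on `code2wav`.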
```python
# Compute per-request terminal stages when the multi-terminal speech
# pipeline is active but this particular request only needs text.
if len(self._terminal_stages) > 1:
    modalities = request.metadata.get("output_modalities")
    if modalities is not None and "audio" not in modalities:
        # Only the text decode terminal is expected.
        text_terminals = {
            s for s in self._terminal_stages if s == "decode"
        }
        if text_terminals:
            self._request_terminals[request_id] = text_terminals
The per-request terminal override logic is currently hard-coded to the string stage name "decode". Since Coordinator is pipeline-generic, this creates a hidden coupling to one specific pipeline naming scheme and will silently fail to override terminals in any pipeline whose text terminal is named differently. Consider deriving the text terminal from configuration (e.g., a constant exported by the pipeline config, or a mapping from modalities->terminal stage set) rather than embedding the stage name here.
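A configuration-driven version of the override could look like the sketch below. This is illustrative only: `MODALITY_TERMINALS` and `expected_terminals` do not exist in the repo, and the idea is simply that each pipeline config exports its own modality-to-terminal mapping instead of the coordinator hard-coding `"decode"`.

```python
# Hypothetical mapping a pipeline config would export, e.g. from
# sglang_omni/models/qwen3_omni/config.py for the speech pipeline.
MODALITY_TERMINALS = {
    "text": {"decode"},
    "audio": {"code2wav"},
}


def expected_terminals(modalities, pipeline_terminals):
    """Terminal stages implied by the requested output modalities."""
    if modalities is None:
        # No preference expressed: expect every pipeline terminal.
        return set(pipeline_terminals)
    stages = set()
    for m in modalities:
        stages |= MODALITY_TERMINALS.get(m, set())
    # Never expect a stage this pipeline does not actually have; fall back
    # to the full terminal set if the request named no known modality.
    return (stages & set(pipeline_terminals)) or set(pipeline_terminals)
```

The coordinator would then call `expected_terminals(...)` generically, leaving stage names entirely to the pipeline definition.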
```python
# Skip streaming to targets suppressed for this request (e.g.
# text-only request skipping audio stages in the speech pipeline).
suppressed = self._suppressed_stream_targets.get(request_id)
if suppressed and target_stage in suppressed:
    return
```
Stream suppression is implemented for bootstrap and normal chunk enqueues, but error propagation still calls _notify_stream_error() which enqueues error signals to all _stream_targets. For a text-only request where audio targets are suppressed (and never receive a submit/data_ready), these error signals can be buffered in the downstream stage’s _pending_stream_data, recreating the kind of per-request leak this change is trying to avoid. Fix: apply the same suppression check when enqueuing done/error signals (e.g., in _enqueue_stream_done/_enqueue_stream_error and/or inside _notify_stream_error) so suppressed targets never receive any stream control messages.
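The fix the comment asks for amounts to routing every stream control message, data, done, and error alike, through one suppression check. The class below is a simplified stand-in for the worker runtime (the `sent` list replaces real queue enqueues for demonstration); only `_suppressed_stream_targets` mirrors a name from the diff.

```python
class Runtime:
    """Sketch: one suppression gate for all per-request stream messages."""

    def __init__(self):
        self._suppressed_stream_targets = {}  # request_id -> set of stage names
        self.sent = []  # (target_stage, request_id, kind); stands in for queues

    def _is_suppressed(self, request_id, target_stage):
        suppressed = self._suppressed_stream_targets.get(request_id)
        return bool(suppressed) and target_stage in suppressed

    def _enqueue(self, target_stage, request_id, kind):
        # The same check guards "chunk", "done", and "error" messages, so a
        # suppressed audio target never buffers anything in its
        # _pending_stream_data for this request.
        if self._is_suppressed(request_id, target_stage):
            return
        self.sent.append((target_stage, request_id, kind))
```

With this shape, `_notify_stream_error()` would simply call `_enqueue(..., kind="error")` and inherit the suppression behavior for free.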
```python
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock
```
The test suite in this repo primarily uses pytest (fixtures, @pytest.mark.asyncio, etc.), but this new test file is written as unittest.TestCase. Pytest will run it, but it’s inconsistent with the existing test conventions and makes async testing/loop management more awkward. Consider rewriting these tests using pytest + pytest-asyncio for consistency with the rest of tests/.
```python
asyncio.get_event_loop().run_until_complete(
    coord._submit_request("req-text", request)
)
```
These tests use asyncio.get_event_loop().run_until_complete(...). On newer Python versions this pattern is deprecated and can be flaky depending on event-loop policy; the rest of the repo generally uses asyncio.run(...) or @pytest.mark.asyncio async tests. Switching to asyncio.run(coord._submit_request(...)) (or making the test async under pytest) would avoid deprecation warnings and make loop ownership explicit.
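The recommended pattern looks like this. The `submit` coroutine below is a stand-in for `coord._submit_request(...)`, which is not reproduced here:

```python
import asyncio


async def submit(request_id: str) -> str:
    # Stand-in for coord._submit_request(...): any awaitable works the same.
    await asyncio.sleep(0)
    return request_id


# asyncio.run() creates and closes its own event loop, making loop
# ownership explicit and avoiding the DeprecationWarning that
# get_event_loop().run_until_complete(...) raises on recent Python
# versions when no loop is running.
result = asyncio.run(submit("req-text"))
```

Under pytest-asyncio the same call becomes an `async def` test decorated with `@pytest.mark.asyncio` that simply `await`s the coroutine, with no manual loop management at all.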
Summary
Closes #61.
- Requests with `modalities=["text"]` now skip audio stages entirely (talker_ar, code_predictor, code2wav) instead of running them and filtering at the API layer
- Requests with `modalities=["text", "audio"]` or no modalities specified behave exactly as before

Changes
- Routing (`next_stage.py`): new `thinker_next_flexible` checks request metadata to decide whether to fan out to audio stages
- Coordinator (`coordinator.py`): the coordinator now tracks expected terminal stages per request so text-only requests resolve after decode without waiting for code2wav
- Runtime (`runtime.py`): bootstrap, inter-stage streaming, and stream-done signals are suppressed for audio targets on text-only requests to avoid memory leaks
- Config (`config.py`): the speech pipeline now uses the flexible routing function

Test plan
- Unit tests cover `thinker_next_flexible` with all modality combinations