Skip to content

Commit 3ed40e3

Browse files
yxs and zhaochenyang20 authored
[Fix] Stop swallowing OOM in Ming-Omni and Qwen3-Omni talker executors (#302)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
1 parent 3e64a6d commit 3ed40e3

3 files changed

Lines changed: 86 additions & 29 deletions

File tree

sglang_omni/models/ming_omni/components/talker_executor.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,21 +185,13 @@ async def add_request(self, payload: StagePayload) -> None:
185185
return
186186

187187
t0 = time.time()
188-
logger.info("[TALKER] Starting TTS generation for %d chars...", len(text))
189-
try:
190-
waveform, sample_rate, duration = await asyncio.to_thread(
191-
self._generate_speech, text
192-
)
193-
logger.info(
194-
"[TALKER] TTS done in %.1fs, audio=%.2fs", time.time() - t0, duration
195-
)
196-
except Exception as e:
197-
logger.error(
198-
"[TALKER] ERROR after %.1fs: %s", time.time() - t0, e, exc_info=True
199-
)
200-
waveform = None
201-
sample_rate = 44100
202-
duration = 0.0
188+
logger.debug(f"[TALKER] Starting TTS generation for {len(text)} chars...")
189+
waveform, sample_rate, duration = await asyncio.to_thread(
190+
self._generate_speech, text
191+
)
192+
logger.debug(
193+
f"[TALKER] TTS done in {time.time() - t0:.1f}s, audio={duration:.2f}s"
194+
)
203195

204196
# Serialize tensor to bytes for cross-process msgpack transport
205197
if waveform is not None:
@@ -267,6 +259,10 @@ def _generate_speech(self, text: str) -> tuple[torch.Tensor | None, int, float]:
267259
268260
Returns:
269261
Tuple of (waveform tensor, sample_rate, duration in seconds).
262+
263+
Note (Xuesong): the (None, 44100, 0.0) returns below for "no supported
264+
generation method" / "no waveforms produced" are pre-existing soft
265+
failures, kept out of #300's OOM-propagation scope. Tracked in #188.
270266
"""
271267
if self._talker is None:
272268
raise RuntimeError("Talker model not loaded")

sglang_omni/models/qwen3_omni/components/talker_executor.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -500,20 +500,10 @@ def _get_tts_special_embeds(
500500
self._tts_eos_token_id,
501501
self._tts_pad_token_id,
502502
]
503-
hidden_size = self._talker_model.config.thinker_hidden_size
504-
try:
505-
thinker_rows = _load_thinker_embedding_rows(
506-
self._resolved_model_path, special_ids
507-
)
508-
thinker_rows = thinker_rows.to(device=self._device, dtype=self._dtype)
509-
except Exception:
510-
logger.exception("Failed to load thinker special token embeddings")
511-
thinker_rows = torch.zeros(
512-
len(special_ids),
513-
hidden_size,
514-
device=self._device,
515-
dtype=self._dtype,
516-
)
503+
thinker_rows = _load_thinker_embedding_rows(
504+
self._resolved_model_path, special_ids
505+
)
506+
thinker_rows = thinker_rows.to(device=self._device, dtype=self._dtype)
517507

518508
projected = self._talker_model.text_projection(thinker_rows)
519509
tts_bos_embed, tts_eos_embed, tts_pad_embed = projected.chunk(3, dim=0)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Regression tests for talker executor error propagation.
3+
4+
Covers both Ming-Omni and Qwen3-Omni talker TTS paths, ensuring common
5+
exceptions surface to callers instead of being swallowed into silent
6+
None-waveform fallbacks.
7+
8+
Reference: https://github.com/sgl-project/sglang-omni/issues/300
9+
10+
Author:
11+
Xuesong Ye https://github.com/yxs
12+
"""
13+
14+
from __future__ import annotations
15+
16+
from unittest.mock import MagicMock, patch
17+
18+
import pytest
19+
import torch
20+
21+
from sglang_omni.models.ming_omni.components.talker_executor import MingTalkerExecutor
22+
from sglang_omni.models.qwen3_omni.components import talker_executor as qwen3_te
23+
from sglang_omni.proto import OmniRequest, StagePayload
24+
25+
_INJECTED_ERRORS = [
26+
torch.OutOfMemoryError("CUDA OOM (injected)"),
27+
RuntimeError("runtime failure (injected)"),
28+
ValueError("invalid value (injected)"),
29+
]
30+
31+
32+
def _error_id(exc: BaseException) -> str:
33+
return type(exc).__name__
34+
35+
36+
@pytest.mark.parametrize("exc", _INJECTED_ERRORS, ids=_error_id)
37+
@pytest.mark.asyncio
38+
async def test_ming_talker_propagates_errors(exc: BaseException) -> None:
39+
executor = MingTalkerExecutor(model_path="/fake/model/path")
40+
payload = StagePayload(
41+
request_id="t1",
42+
request=MagicMock(spec=OmniRequest),
43+
data={},
44+
)
45+
46+
with (
47+
patch.object(executor, "_extract_text", return_value="hello world"),
48+
patch.object(executor, "_generate_speech", side_effect=exc),
49+
pytest.raises(type(exc)),
50+
):
51+
await executor.add_request(payload)
52+
53+
54+
@pytest.mark.parametrize("exc", _INJECTED_ERRORS, ids=_error_id)
55+
def test_qwen3_talker_propagates_errors(exc: BaseException) -> None:
56+
executor = MagicMock(spec=qwen3_te.TalkerStreamingExecutor)
57+
executor._tts_special_cache = None
58+
executor._tts_bos_token_id = 0
59+
executor._tts_eos_token_id = 1
60+
executor._tts_pad_token_id = 2
61+
executor._resolved_model_path = "/fake/model/path"
62+
63+
with (
64+
patch.object(
65+
qwen3_te,
66+
"_load_thinker_embedding_rows",
67+
side_effect=exc,
68+
),
69+
pytest.raises(type(exc)),
70+
):
71+
qwen3_te.TalkerStreamingExecutor._get_tts_special_embeds(executor)

0 commit comments

Comments (0)