|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Regression tests for talker executor error propagation. |
| 3 | +
|
| 4 | +Covers both Ming-Omni and Qwen3-Omni talker TTS paths, ensuring common |
| 5 | +exceptions surface to callers instead of being swallowed into silent |
| 6 | +None-waveform fallbacks. |
| 7 | +
|
| 8 | +Reference: https://github.com/sgl-project/sglang-omni/issues/300 |
| 9 | +
|
| 10 | +Author: |
| 11 | +Xuesong Ye https://github.com/yxs |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +from unittest.mock import MagicMock, patch |
| 17 | + |
| 18 | +import pytest |
| 19 | +import torch |
| 20 | + |
| 21 | +from sglang_omni.models.ming_omni.components.talker_executor import MingTalkerExecutor |
| 22 | +from sglang_omni.models.qwen3_omni.components import talker_executor as qwen3_te |
| 23 | +from sglang_omni.proto import OmniRequest, StagePayload |
| 24 | + |
| 25 | +_INJECTED_ERRORS = [ |
| 26 | + torch.OutOfMemoryError("CUDA OOM (injected)"), |
| 27 | + RuntimeError("runtime failure (injected)"), |
| 28 | + ValueError("invalid value (injected)"), |
| 29 | +] |
| 30 | + |
| 31 | + |
| 32 | +def _error_id(exc: BaseException) -> str: |
| 33 | + return type(exc).__name__ |
| 34 | + |
| 35 | + |
| 36 | +@pytest.mark.parametrize("exc", _INJECTED_ERRORS, ids=_error_id) |
| 37 | +@pytest.mark.asyncio |
| 38 | +async def test_ming_talker_propagates_errors(exc: BaseException) -> None: |
| 39 | + executor = MingTalkerExecutor(model_path="/fake/model/path") |
| 40 | + payload = StagePayload( |
| 41 | + request_id="t1", |
| 42 | + request=MagicMock(spec=OmniRequest), |
| 43 | + data={}, |
| 44 | + ) |
| 45 | + |
| 46 | + with ( |
| 47 | + patch.object(executor, "_extract_text", return_value="hello world"), |
| 48 | + patch.object(executor, "_generate_speech", side_effect=exc), |
| 49 | + pytest.raises(type(exc)), |
| 50 | + ): |
| 51 | + await executor.add_request(payload) |
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.parametrize("exc", _INJECTED_ERRORS, ids=_error_id) |
| 55 | +def test_qwen3_talker_propagates_errors(exc: BaseException) -> None: |
| 56 | + executor = MagicMock(spec=qwen3_te.TalkerStreamingExecutor) |
| 57 | + executor._tts_special_cache = None |
| 58 | + executor._tts_bos_token_id = 0 |
| 59 | + executor._tts_eos_token_id = 1 |
| 60 | + executor._tts_pad_token_id = 2 |
| 61 | + executor._resolved_model_path = "/fake/model/path" |
| 62 | + |
| 63 | + with ( |
| 64 | + patch.object( |
| 65 | + qwen3_te, |
| 66 | + "_load_thinker_embedding_rows", |
| 67 | + side_effect=exc, |
| 68 | + ), |
| 69 | + pytest.raises(type(exc)), |
| 70 | + ): |
| 71 | + qwen3_te.TalkerStreamingExecutor._get_tts_special_embeds(executor) |
0 commit comments