From 3a53c3b43c2566dd4f848fbf38908cc5ae18fec2 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Fri, 6 Feb 2026 03:29:03 +0800 Subject: [PATCH 1/2] add streaming text input support for qwen3 tts Signed-off-by: lishunyang --- PR_description.md | 97 ++++ examples/online_serving/qwen3_tts/README.md | 83 +++- .../qwen3_tts/streaming_speech_client.py | 235 ++++++++++ .../openai_api/test_serving_speech_stream.py | 430 ++++++++++++++++++ .../openai_api/test_text_splitter.py | 229 ++++++++++ vllm_omni/entrypoints/openai/api_server.py | 26 +- .../entrypoints/openai/protocol/audio.py | 16 + .../entrypoints/openai/serving_speech.py | 196 ++++---- .../openai/serving_speech_stream.py | 242 ++++++++++ vllm_omni/entrypoints/openai/text_splitter.py | 91 ++++ 10 files changed, 1552 insertions(+), 93 deletions(-) create mode 100644 PR_description.md create mode 100644 examples/online_serving/qwen3_tts/streaming_speech_client.py create mode 100644 tests/entrypoints/openai_api/test_serving_speech_stream.py create mode 100644 tests/entrypoints/openai_api/test_text_splitter.py create mode 100644 vllm_omni/entrypoints/openai/serving_speech_stream.py create mode 100644 vllm_omni/entrypoints/openai/text_splitter.py diff --git a/PR_description.md b/PR_description.md new file mode 100644 index 0000000000..027d82f157 --- /dev/null +++ b/PR_description.md @@ -0,0 +1,97 @@ +# [Feature][TTS] Streaming Text Input for Qwen3-TTS via WebSocket + +## Summary + +Add a WebSocket endpoint `/v1/audio/speech/stream` that accepts text input incrementally (e.g., from a real-time STT pipeline), buffers and splits at sentence boundaries, and generates audio per sentence using the existing TTS pipeline. + +This enables real-time text-to-speech workflows where text is produced progressively (speech-to-text, LLM token streaming, live captions) and audio needs to be generated as soon as complete sentences are available, rather than waiting for the entire input. + +**Scope:** Streaming text *input* only. Each sentence produces a complete audio response. Streaming audio *output* (chunked PCM) is tracked separately in PR #1189. + +## Motivation + +The current `/v1/audio/speech` REST endpoint requires the full text upfront. In real-time pipelines (e.g., STT → LLM → TTS), text arrives incrementally. Without streaming input support, clients must either: +1. Wait for the entire text before calling TTS (high latency), or +2. Manually implement sentence buffering and make multiple REST calls (complex, no session state). + +This PR solves both issues with a single WebSocket session that handles buffering, sentence detection, and per-sentence generation automatically. + +## WebSocket Protocol + +**Transport:** WebSocket (industry standard — used by OpenAI Realtime API, ElevenLabs, Azure TTS) + +### Client → Server + +```jsonc +// 1. Session config (sent once, first message) +{"type": "session.config", "voice": "Vivian", "task_type": "CustomVoice", "language": "Auto"} + +// 2. Text chunks (sent incrementally, any number of times) +{"type": "input.text", "text": "Hello, how are you? "} + +// 3. 
End of input (flushes remaining buffer) +{"type": "input.done"} +``` + +### Server → Client + +```jsonc +// Per-sentence: metadata → binary audio → completion +{"type": "audio.start", "sentence_index": 0, "sentence_text": "Hello, how are you?", "format": "wav"} + +{"type": "audio.done", "sentence_index": 0} + +// Session complete +{"type": "session.done", "total_sentences": 3} + +// Non-fatal error (session continues) +{"type": "error", "message": "..."} +``` + +## Changes + +### New Files + +| File | Description | +|------|-------------| +| `vllm_omni/entrypoints/openai/text_splitter.py` | `SentenceSplitter` — incremental sentence boundary detector. Regex-based splitting at English `.!?` + whitespace and CJK fullwidth `。!?,;`. Configurable `min_sentence_length` (default 2, CJK-friendly). | +| `vllm_omni/entrypoints/openai/serving_speech_stream.py` | `OmniStreamingSpeechHandler` — WebSocket session handler. Manages config validation, idle/config timeouts (30s/10s), per-sentence audio generation, and error resilience (one sentence failure doesn't kill the session). | +| `examples/online_serving/qwen3_tts/streaming_speech_client.py` | Python WebSocket client example. Supports `--simulate-stt` mode (word-by-word with configurable delay), all 3 task types (CustomVoice, VoiceDesign, Base), saves per-sentence audio files. | +| `tests/entrypoints/openai_api/test_text_splitter.py` | Unit tests for `SentenceSplitter`: English/Chinese/mixed splitting, incremental accumulation, flush behavior, edge cases. | +| `tests/entrypoints/openai_api/test_serving_speech_stream.py` | WebSocket integration tests: session lifecycle, multi-sentence, incremental text, flush-on-done, empty input, invalid config, invalid JSON, unknown message types, generation failure recovery. | + +### Modified Files + +| File | Description | +|------|-------------| +| `vllm_omni/entrypoints/openai/serving_speech.py` | **Refactor:** Extracted `_generate_audio_bytes(request) → (bytes, media_type)` from `create_speech()`. The REST endpoint delegates to it; the WebSocket handler reuses it per sentence. No behavior change for existing callers. | +| `vllm_omni/entrypoints/openai/protocol/audio.py` | Added `StreamingSpeechSessionConfig` Pydantic model for WebSocket session configuration (mirrors `OpenAICreateSpeechRequest` fields minus `input`). | +| `vllm_omni/entrypoints/openai/api_server.py` | Added `@router.websocket("/v1/audio/speech/stream")` route and `OmniStreamingSpeechHandler` initialization in `omni_init_app_state()`. | +| `examples/online_serving/qwen3_tts/README.md` | Added streaming text input documentation section with protocol spec, parameters, and usage examples. | + +## Design Decisions + +| Decision | Rationale | +|----------|-----------| +| **WebSocket** (not SSE/HTTP chunked) | Bidirectional: client sends text incrementally, server sends audio back. Industry standard for real-time TTS (OpenAI, ElevenLabs, Azure). | +| **Sentence-level chunking** (not token/word) | Natural speech boundaries produce coherent audio. Avoids artifacts from splitting mid-sentence. Production-ready granularity. | +| **`min_sentence_length=2`** | Prevents splitting on lone punctuation (`.`) while supporting short CJK sentences like `你好!` (3 chars). | +| **`_generate_audio_bytes()` extraction** | Clean separation of concerns. REST endpoint wraps in `Response`; WebSocket sends raw bytes. No code duplication. 
| +| **Per-sentence error resilience** | If generation fails for one sentence, an error is sent but the session continues for remaining sentences. | +| **Idle + config timeouts** | Prevents resource leaks from abandoned connections (30s idle, 10s for initial config). | + +## Test Plan + +- [ ] `pytest tests/entrypoints/openai_api/test_text_splitter.py` — sentence splitter unit tests +- [ ] `pytest tests/entrypoints/openai_api/test_serving_speech_stream.py` — WebSocket integration tests +- [ ] `pytest tests/entrypoints/openai_api/test_serving_speech.py` — existing REST endpoint tests (verify refactor is non-breaking) +- [ ] Manual test with live server: + ```bash + # Start server + ./examples/online_serving/qwen3_tts/run_server.sh CustomVoice + + # Run streaming client + python examples/online_serving/qwen3_tts/streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt + ``` diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md index 1c9bd48203..bb6366ed5b 100644 --- a/examples/online_serving/qwen3_tts/README.md +++ b/examples/online_serving/qwen3_tts/README.md @@ -162,9 +162,90 @@ with open("output.wav", "wb") as f: f.write(response.content) ``` +## Streaming Text Input + +The `/v1/audio/speech/stream` WebSocket endpoint accepts text incrementally (e.g., from a real-time STT pipeline), buffers and splits at sentence boundaries, and generates audio per sentence. + +> **Note:** This is streaming *text input* only. Each sentence produces a complete audio response. For streaming audio output, see PR #1189. + +### Quick Start + +```bash +# Send full text (sentences are auto-detected) +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." + +# Simulate STT: send text word-by-word +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 + +# VoiceDesign task +python streaming_speech_client.py \ + --text "Today is a great day. The weather is nice." \ + --task-type VoiceDesign \ + --instructions "A cheerful young female voice" +``` + +### WebSocket Protocol + +**Client -> Server:** + +```jsonc +// 1. Session config (sent once first) +{"type": "session.config", "voice": "Vivian", "task_type": "CustomVoice", "language": "Auto"} + +// 2. Text chunks (sent incrementally) +{"type": "input.text", "text": "Hello, how are you? "} + +// 3. 
End of input (flushes remaining buffer) +{"type": "input.done"} +``` + +**Server -> Client:** + +```jsonc +// Audio metadata (before binary frame) +{"type": "audio.start", "sentence_index": 0, "sentence_text": "Hello, how are you?", "format": "wav"} + +// Binary WebSocket frame: raw audio bytes + +// Per-sentence completion +{"type": "audio.done", "sentence_index": 0} + +// Session complete +{"type": "session.done", "total_sentences": 3} + +// Error (non-fatal, session continues) +{"type": "error", "message": "..."} +``` + +### Session Config Parameters + +All parameters from the REST API are supported: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `voice` | string | None | Speaker voice name | +| `task_type` | string | None | CustomVoice, VoiceDesign, or Base | +| `language` | string | None | Language code | +| `instructions` | string | None | Voice style instructions | +| `response_format` | string | "wav" | Audio format per sentence | +| `speed` | float | 1.0 | Playback speed (0.25-4.0) | +| `max_new_tokens` | int | None | Max tokens per sentence | +| `ref_audio` | string | None | Reference audio (Base task) | +| `ref_text` | string | None | Reference text (Base task) | + +### Sentence Detection + +Text is automatically split at sentence boundaries: +- **English:** `.` `!` `?` followed by whitespace +- **CJK:** fullwidth punctuation `。` `!` `?` `,` `;` + +If text never forms a complete sentence, it is flushed when `input.done` is sent. + ## Limitations -- **No streaming**: Audio is generated completely before being returned. Streaming will be supported after the pipeline is disaggregated (see RFC #938). - **Single request**: Batch processing is not yet optimized for online serving. ## Troubleshooting diff --git a/examples/online_serving/qwen3_tts/streaming_speech_client.py b/examples/online_serving/qwen3_tts/streaming_speech_client.py new file mode 100644 index 0000000000..d595d3d105 --- /dev/null +++ b/examples/online_serving/qwen3_tts/streaming_speech_client.py @@ -0,0 +1,235 @@ +"""WebSocket client for streaming text-input TTS. + +Connects to the /v1/audio/speech/stream endpoint, sends text incrementally +(simulating real-time STT output), and saves per-sentence audio files. + +Usage: + # Send full text at once + python streaming_speech_client.py --text "Hello world. How are you? I am fine." + + # Simulate STT: send text word-by-word with delay + python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 + + # VoiceDesign task + python streaming_speech_client.py \ + --text "Today is a great day. The weather is nice." \ + --task-type VoiceDesign \ + --instructions "A cheerful young female voice" + + # Base task (voice cloning) + python streaming_speech_client.py \ + --text "Hello world. How are you?" \ + --task-type Base \ + --ref-audio /path/to/reference.wav \ + --ref-text "Transcript of reference audio" + +Requirements: + pip install websockets +""" + +import argparse +import asyncio +import json +import os + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + raise SystemExit(1) + + +async def stream_tts( + url: str, + text: str, + config: dict, + output_dir: str, + simulate_stt: bool = False, + stt_delay: float = 0.1, +) -> None: + """Connect to the streaming TTS endpoint and process audio responses.""" + os.makedirs(output_dir, exist_ok=True) + + async with websockets.connect(url) as ws: + # 1. 
Send session config + config_msg = {"type": "session.config", **config} + await ws.send(json.dumps(config_msg)) + print(f"Sent session config: {config}") + + # 2. Send text (either all at once or word-by-word) + async def send_text(): + if simulate_stt: + words = text.split(" ") + for i, word in enumerate(words): + chunk = word + (" " if i < len(words) - 1 else "") + await ws.send(json.dumps({ + "type": "input.text", + "text": chunk, + })) + print(f" Sent: {chunk!r}") + await asyncio.sleep(stt_delay) + else: + await ws.send(json.dumps({ + "type": "input.text", + "text": text, + })) + print(f"Sent full text: {text!r}") + + # 3. Signal end of input + await ws.send(json.dumps({"type": "input.done"})) + print("Sent input.done") + + # Run sender and receiver concurrently + sender_task = asyncio.create_task(send_text()) + + response_format = config.get("response_format", "wav") + sentence_count = 0 + + try: + while True: + message = await ws.recv() + + if isinstance(message, bytes): + # Binary frame: audio data + filename = os.path.join( + output_dir, + f"sentence_{sentence_count:03d}.{response_format}", + ) + with open(filename, "wb") as f: + f.write(message) + print(f" Saved audio: {filename} ({len(message)} bytes)") + sentence_count += 1 + else: + # JSON frame + msg = json.loads(message) + msg_type = msg.get("type") + + if msg_type == "audio.start": + print( + f" [sentence {msg['sentence_index']}] " + f"Generating: {msg['sentence_text']!r}" + ) + elif msg_type == "audio.done": + print( + f" [sentence {msg['sentence_index']}] Done" + ) + elif msg_type == "session.done": + print( + f"\nSession complete: {msg['total_sentences']} " + f"sentence(s) generated" + ) + break + elif msg_type == "error": + print(f" ERROR: {msg['message']}") + else: + print(f" Unknown message: {msg}") + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + + print(f"\nAudio files saved to: {output_dir}/") + + +def main(): + parser = argparse.ArgumentParser( + description="Streaming text-input TTS client" + ) + parser.add_argument( + "--url", + default="ws://localhost:8000/v1/audio/speech/stream", + help="WebSocket endpoint URL", + ) + parser.add_argument( + "--text", + required=True, + help="Text to synthesize", + ) + parser.add_argument( + "--output-dir", + default="streaming_tts_output", + help="Directory to save audio files (default: streaming_tts_output)", + ) + + # Session config options + parser.add_argument("--model", default=None, help="Model name") + parser.add_argument("--voice", default="Vivian", help="Speaker voice") + parser.add_argument( + "--task-type", + default="CustomVoice", + choices=["CustomVoice", "VoiceDesign", "Base"], + help="TTS task type", + ) + parser.add_argument("--language", default="Auto", help="Language") + parser.add_argument( + "--instructions", default=None, help="Voice style instructions" + ) + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], + help="Audio format", + ) + parser.add_argument( + "--speed", type=float, default=1.0, help="Playback speed (0.25-4.0)" + ) + parser.add_argument( + "--max-new-tokens", type=int, default=None, help="Max tokens" + ) + + # Base task options + parser.add_argument("--ref-audio", default=None, help="Reference audio") + parser.add_argument("--ref-text", default=None, help="Reference text") + parser.add_argument( + "--x-vector-only-mode", + action="store_true", + default=False, + help="Speaker embedding only mode", + ) + + # STT simulation 
+ parser.add_argument( + "--simulate-stt", + action="store_true", + help="Simulate STT by sending text word-by-word", + ) + parser.add_argument( + "--stt-delay", + type=float, + default=0.1, + help="Delay between words in STT simulation (seconds)", + ) + + args = parser.parse_args() + + # Build session config (only include non-None values) + config = {} + for key in [ + "model", "voice", "task_type", "language", "instructions", + "response_format", "speed", "max_new_tokens", "ref_audio", + "ref_text", + ]: + val = getattr(args, key.replace("-", "_"), None) + if val is not None: + config[key] = val + if args.x_vector_only_mode: + config["x_vector_only_mode"] = True + + asyncio.run( + stream_tts( + url=args.url, + text=args.text, + config=config, + output_dir=args.output_dir, + simulate_stt=args.simulate_stt, + stt_delay=args.stt_delay, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py new file mode 100644 index 0000000000..edf8b16f7b --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py @@ -0,0 +1,430 @@ +"""Integration tests for the streaming speech WebSocket endpoint.""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +import torch +from fastapi import FastAPI +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +from vllm_omni.entrypoints.openai.protocol.audio import ( + AudioResponse, + StreamingSpeechSessionConfig, +) +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import ( + OmniStreamingSpeechHandler, +) +from vllm_omni.outputs import OmniRequestOutput + + +def _create_mock_speech_service(): + """Create a mock OmniOpenAIServingSpeech for testing.""" + service = MagicMock(spec=OmniOpenAIServingSpeech) + + # Mock _generate_audio_bytes to return fake WAV data + async def mock_generate(request): + # Return minimal WAV-like bytes and media type + fake_audio = b"RIFF" + b"\x00" * 100 # Fake WAV header + data + return fake_audio, "audio/wav" + + service._generate_audio_bytes = AsyncMock(side_effect=mock_generate) + return service + + +def _build_test_app( + speech_service=None, + idle_timeout=5.0, + config_timeout=5.0, +): + """Build a FastAPI app with the streaming speech WebSocket route.""" + if speech_service is None: + speech_service = _create_mock_speech_service() + + handler = OmniStreamingSpeechHandler( + speech_service=speech_service, + idle_timeout=idle_timeout, + config_timeout=config_timeout, + ) + + app = FastAPI() + + @app.websocket("/v1/audio/speech/stream") + async def ws_endpoint(websocket): + await handler.handle_session(websocket) + + return app, handler, speech_service + + +class TestStreamingSpeechBasicLifecycle: + """Test basic session lifecycle: config -> text -> done.""" + + def test_single_sentence(self): + app, _, service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + # Send config + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + "response_format": "wav", + }) + + # Send text with sentence boundary + ws.send_json({ + "type": "input.text", + "text": "Hello world. 
", + }) + + # Receive audio.start + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + assert msg["sentence_text"] == "Hello world." + assert msg["format"] == "wav" + + # Receive binary audio + audio_data = ws.receive_bytes() + assert len(audio_data) > 0 + + # Receive audio.done + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 0 + + # Send done + ws.send_json({"type": "input.done"}) + + # Receive session.done + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + def test_multiple_sentences(self): + app, _, service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + ws.send_json({ + "type": "input.text", + "text": "Hello world. How are you? ", + }) + + # First sentence + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + ws.receive_bytes() # audio + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + # Second sentence + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 1 + ws.receive_bytes() # audio + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 2 + + def test_incremental_text(self): + """Text arrives word-by-word, sentence formed across chunks.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + # Send text incrementally + ws.send_json({"type": "input.text", "text": "Hello "}) + ws.send_json({"type": "input.text", "text": "world. "}) + + # Now a sentence boundary was hit + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_text"] == "Hello world." 
+ ws.receive_bytes() + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + # Send done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + +class TestStreamingSpeechFlush: + """Test flush behavior when text has no sentence boundary.""" + + def test_flush_on_done(self): + """Text without sentence boundary is flushed on input.done.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + ws.send_json({ + "type": "input.text", + "text": "Hello world without punctuation", + }) + + # No sentence boundary, so nothing generated yet + # Now send done to flush + ws.send_json({"type": "input.done"}) + + # Should get audio for the flushed text + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert "Hello world without punctuation" in msg["sentence_text"] + ws.receive_bytes() + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + def test_empty_flush(self): + """input.done with empty buffer produces no audio.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + # Send nothing, just done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 0 + + +class TestStreamingSpeechErrors: + """Test error handling scenarios.""" + + def test_missing_config(self): + """Sending text before config should produce an error.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + # Send text instead of config + ws.send_json({ + "type": "input.text", + "text": "Hello", + }) + + # Should get error about expecting session.config + msg = ws.receive_json() + assert msg["type"] == "error" + assert "session.config" in msg["message"] + + def test_invalid_json(self): + """Invalid JSON should produce an error but not kill session.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + # Send invalid JSON + ws.send_text("not json at all") + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Invalid JSON" in msg["message"] + + # Session should still be alive — send done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_unknown_message_type(self): + """Unknown message type produces error but doesn't kill session.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + ws.send_json({"type": "unknown.type"}) + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Unknown message type" in msg["message"] + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_generation_failure_continues_session(self): + """If generation fails for one sentence, session 
continues.""" + service = _create_mock_speech_service() + + call_count = 0 + + async def mock_generate_with_failure(request): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Generation failed") + return b"RIFF" + b"\x00" * 100, "audio/wav" + + service._generate_audio_bytes = AsyncMock( + side_effect=mock_generate_with_failure + ) + + app, _, _ = _build_test_app(speech_service=service) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + }) + + # First sentence will fail + ws.send_json({ + "type": "input.text", + "text": "First sentence. Second sentence. ", + }) + + # First sentence: audio.start -> error -> audio.done + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Generation failed" in msg["message"] + + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 0 + + # Second sentence should succeed + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 1 + ws.receive_bytes() # audio data + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 1 + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + +class TestStreamingSpeechSessionConfig: + """Test session config validation.""" + + def test_valid_config_all_fields(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "voice": "Vivian", + "task_type": "CustomVoice", + "language": "English", + "instructions": "Speak cheerfully", + "response_format": "mp3", + "speed": 1.5, + "max_new_tokens": 1024, + }) + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_invalid_speed(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "speed": 10.0, # invalid: > 4.0 + }) + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Invalid session config" in msg["message"] + + def test_invalid_response_format(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({ + "type": "session.config", + "response_format": "invalid", + }) + + msg = ws.receive_json() + assert msg["type"] == "error" + + +class TestStreamingSpeechConfigModel: + """Unit tests for StreamingSpeechSessionConfig model.""" + + def test_defaults(self): + config = StreamingSpeechSessionConfig() + assert config.response_format == "wav" + assert config.speed == 1.0 + assert config.voice is None + assert config.task_type is None + + def test_all_fields(self): + config = StreamingSpeechSessionConfig( + model="test", + voice="Vivian", + task_type="CustomVoice", + language="English", + instructions="test", + response_format="mp3", + speed=2.0, + max_new_tokens=1024, + ref_audio="http://example.com/audio.wav", + ref_text="hello", + x_vector_only_mode=True, + ) + assert config.voice == "Vivian" + assert config.speed == 2.0 + assert config.response_format == "mp3" diff --git 
a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py new file mode 100644 index 0000000000..988172e009 --- /dev/null +++ b/tests/entrypoints/openai_api/test_text_splitter.py @@ -0,0 +1,229 @@ +"""Tests for SentenceSplitter used in streaming TTS input.""" + +import pytest + +from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter + + +class TestSentenceSplitterEnglish: + """Tests for English sentence splitting.""" + + def test_single_sentence_no_boundary(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world") + assert result == [] + assert splitter.buffer == "Hello world" + + def test_single_sentence_with_boundary(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world. How are you?") + assert len(result) == 1 + assert result[0] == "Hello world." + + def test_multiple_sentences(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello. How are you? I am fine! ") + assert len(result) == 3 + assert result[0] == "Hello." + assert result[1] == "How are you?" + assert result[2] == "I am fine!" + + def test_exclamation_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Wow, that is great! Tell me more.") + assert len(result) == 1 + assert result[0] == "Wow, that is great!" + + def test_question_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Can you hear me? I hope so.") + assert len(result) == 1 + assert result[0] == "Can you hear me?" + + +class TestSentenceSplitterChinese: + """Tests for CJK sentence splitting.""" + + def test_chinese_period(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好世界。你好吗") + assert len(result) == 1 + assert result[0] == "你好世界。" + + def test_chinese_exclamation(self): + splitter = SentenceSplitter() + result = splitter.add_text("太好了!谢谢你") + assert len(result) == 1 + assert result[0] == "太好了!" + + def test_chinese_question(self): + splitter = SentenceSplitter() + result = splitter.add_text("你是谁?我是小明") + assert len(result) == 1 + assert result[0] == "你是谁?" + + def test_chinese_comma(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好,世界") + assert len(result) == 1 + assert result[0] == "你好," + + def test_chinese_semicolon(self): + splitter = SentenceSplitter() + result = splitter.add_text("第一点;第二点") + assert len(result) == 1 + assert result[0] == "第一点;" + + def test_chinese_multiple(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好!你好吗?我很好。") + assert len(result) == 3 + assert result[0] == "你好!" + assert result[1] == "你好吗?" + assert result[2] == "我很好。" + + +class TestSentenceSplitterMixed: + """Tests for mixed-language sentence splitting.""" + + def test_mixed_english_chinese(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello世界。How are you? ") + assert len(result) == 2 + assert result[0] == "Hello世界。" + assert result[1] == "How are you?" + + +class TestSentenceSplitterIncremental: + """Tests for incremental (multi-chunk) text input.""" + + def test_accumulation_across_chunks(self): + splitter = SentenceSplitter() + # First chunk: no boundary + result1 = splitter.add_text("Hello ") + assert result1 == [] + + # Second chunk: completes a sentence + result2 = splitter.add_text("world. How") + assert len(result2) == 1 + assert result2[0] == "Hello world." + assert splitter.buffer == "How" + + def test_word_by_word(self): + splitter = SentenceSplitter() + words = ["Hello, ", "how ", "are ", "you? 
", "I ", "am ", "fine."] + all_sentences = [] + for word in words: + all_sentences.extend(splitter.add_text(word)) + + assert len(all_sentences) == 1 + assert all_sentences[0] == "Hello, how are you?" + # "I am fine." stays in buffer (no trailing whitespace after period) + + def test_three_chunks(self): + splitter = SentenceSplitter() + splitter.add_text("The quick brown ") + splitter.add_text("fox jumps. ") + result = splitter.add_text("Over the lazy dog. ") + # "The quick brown fox jumps." should have been returned on second chunk + # "Over the lazy dog." on third chunk + assert len(result) == 1 + assert result[0] == "Over the lazy dog." + + +class TestSentenceSplitterFlush: + """Tests for flush behavior.""" + + def test_flush_returns_remaining(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world") + result = splitter.flush() + assert result == "Hello world" + assert splitter.buffer == "" + + def test_flush_empty_buffer(self): + splitter = SentenceSplitter() + result = splitter.flush() + assert result is None + + def test_flush_after_sentence(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world. Remaining text") + result = splitter.flush() + assert result == "Remaining text" + + def test_flush_whitespace_only(self): + splitter = SentenceSplitter() + splitter.add_text("Hello. ") + # "Hello." extracted, buffer is " " + result = splitter.flush() + # Whitespace-only should return None + assert result is None + + def test_flush_clears_buffer(self): + splitter = SentenceSplitter() + splitter.add_text("some text") + splitter.flush() + assert splitter.buffer == "" + # Second flush should return None + assert splitter.flush() is None + + +class TestSentenceSplitterEdgeCases: + """Edge case tests.""" + + def test_empty_input(self): + splitter = SentenceSplitter() + result = splitter.add_text("") + assert result == [] + assert splitter.buffer == "" + + def test_none_like_empty(self): + """Empty string should not affect buffer.""" + splitter = SentenceSplitter() + splitter.add_text("Hello") + splitter.add_text("") + assert splitter.buffer == "Hello" + + def test_only_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text(". ") + # "." is 1 char, below default min_sentence_length of 2 + # It will be carried forward + assert result == [] + + def test_min_sentence_length(self): + splitter = SentenceSplitter(min_sentence_length=10) + result = splitter.add_text("Hi. Hello world. ") + # "Hi." is 3 chars (< 10), so it gets carried to "Hello world." + assert len(result) == 1 + assert "Hi." in result[0] + assert "Hello world." in result[0] + + def test_min_sentence_length_zero(self): + splitter = SentenceSplitter(min_sentence_length=0) + result = splitter.add_text("A. B. ") + assert len(result) == 2 + + def test_no_boundary_then_flush(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world how are you") + assert result == [] + flushed = splitter.flush() + assert flushed == "Hello world how are you" + + def test_consecutive_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text("Really?! Yes, really. ") + assert len(result) >= 1 + + def test_reuse_after_flush(self): + """Splitter can be reused after flush.""" + splitter = SentenceSplitter() + splitter.add_text("First session.") + splitter.flush() + + result = splitter.add_text("Second session. More text") + assert len(result) == 1 + assert result[0] == "Second session." 
+ assert splitter.buffer == "More text" diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index fb52c7e464..f95b0542cc 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -19,7 +19,7 @@ import httpx import vllm.envs as envs -from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, WebSocket from fastapi.responses import JSONResponse, StreamingResponse from PIL import Image from starlette.datastructures import State @@ -87,6 +87,7 @@ ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id @@ -684,6 +685,10 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, request_logger=request_logger ) + state.openai_streaming_speech = OmniStreamingSpeechHandler( + speech_service=state.openai_serving_speech, + ) + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -819,6 +824,25 @@ async def list_voices(raw_request: Request): return JSONResponse(content={"voices": speakers}) +@router.websocket("/v1/audio/speech/stream") +async def streaming_speech(websocket: WebSocket): + """WebSocket endpoint for streaming text input TTS. + + Accepts text incrementally, splits at sentence boundaries, and + returns audio per sentence. See serving_speech_stream.py for protocol. 
+ """ + handler = getattr(websocket.app.state, "openai_streaming_speech", None) + if handler is None: + await websocket.accept() + await websocket.send_json({ + "type": "error", + "message": "Streaming speech is not available", + }) + await websocket.close() + return + await handler.handle_session(websocket) + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index d23460626b..9c8a1a58a3 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -72,3 +72,19 @@ class Config: class AudioResponse(BaseModel): audio_data: bytes | str media_type: str + + +class StreamingSpeechSessionConfig(BaseModel): + """Configuration sent as the first WebSocket message for streaming TTS.""" + + model: str | None = None + voice: str | None = None + task_type: Literal["CustomVoice", "VoiceDesign", "Base"] | None = None + language: str | None = None + instructions: str | None = None + response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" + speed: float | None = Field(default=1.0, ge=0.25, le=4.0) + max_new_tokens: int | None = None + ref_audio: str | None = None + ref_text: str | None = None + x_vector_only_mode: bool | None = None diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index a8bae9e993..617a19ff0b 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -180,6 +180,109 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any return params + async def _generate_audio_bytes( + self, + request: OpenAICreateSpeechRequest, + ) -> tuple[bytes, str]: + """Core TTS generation logic: validate, generate, and encode audio. + + Extracted from create_speech() so it can be reused by the streaming + WebSocket handler for per-sentence generation. + + Args: + request: The speech request with text and parameters. + + Returns: + Tuple of (audio_bytes, media_type). + + Raises: + ValueError: If validation fails or generation produces no output. + """ + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Validate TTS parameters + if self._is_tts_model(): + validation_error = self._validate_tts_request(request) + if validation_error: + raise ValueError(validation_error) + + tts_params = self._build_tts_params(request) + prompt_text = self._build_tts_prompt(request.input) + prompt = { + "prompt": prompt_text, + "additional_information": tts_params, + } + else: + tts_params = {} + prompt = {"prompt": request.input} + + request_id = f"speech-{random_uuid()}" + + logger.info( + "TTS speech request %s: text=%r, task_type=%s", + request_id, + request.input[:50] + "..." 
if len(request.input) > 50 else request.input, + tts_params.get("task_type", ["unknown"])[0], + ) + + sampling_params_list = self.engine_client.default_sampling_params_list + + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + + final_output: OmniRequestOutput | None = None + async for res in generator: + final_output = res + + if final_output is None: + raise ValueError("No output generated from the model.") + + # Extract audio from output + audio_output = None + if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: + audio_output = final_output.multimodal_output + if not audio_output and hasattr(final_output, "request_output"): + if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): + audio_output = final_output.request_output.multimodal_output + + audio_key = None + if audio_output: + if "audio" in audio_output: + audio_key = "audio" + elif "model_outputs" in audio_output: + audio_key = "model_outputs" + + if not audio_output or audio_key is None: + raise ValueError("TTS model did not produce audio output.") + + audio_tensor = audio_output[audio_key] + sample_rate = audio_output.get("sr", 24000) + if hasattr(sample_rate, "item"): + sample_rate = sample_rate.item() + + if hasattr(audio_tensor, "float"): + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.squeeze() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=int(sample_rate), + response_format=request.response_format or "wav", + speed=request.speed or 1.0, + stream_format=request.stream_format, + base64_encode=False, + ) + + audio_response: AudioResponse = self.create_audio(audio_obj) + return audio_response.audio_data, audio_response.media_type + async def create_speech( self, request: OpenAICreateSpeechRequest, @@ -209,98 +312,9 @@ async def create_speech( logger.error("Error with model %s", error_check_ret) return error_check_ret - if self.engine_client.errored: - raise self.engine_client.dead_error - - request_id = f"speech-{random_uuid()}" - try: - if self._is_tts_model(): - # Validate TTS parameters - validation_error = self._validate_tts_request(request) - if validation_error: - return self.create_error_response(validation_error) - - # Build TTS parameters and prompt - tts_params = self._build_tts_params(request) - prompt_text = self._build_tts_prompt(request.input) - prompt = { - "prompt": prompt_text, - "additional_information": tts_params, - } - else: - # Fallback for unsupported models - tts_params = {} - prompt = {"prompt": request.input} - - logger.info( - "TTS speech request %s: text=%r, task_type=%s", - request_id, - request.input[:50] + "..." 
if len(request.input) > 50 else request.input, - tts_params.get("task_type", ["unknown"])[0], - ) - - sampling_params_list = self.engine_client.default_sampling_params_list - - generator = self.engine_client.generate( - prompt=prompt, - request_id=request_id, - sampling_params_list=sampling_params_list, - output_modalities=["audio"], - ) - - final_output: OmniRequestOutput | None = None - async for res in generator: - final_output = res - - if final_output is None: - return self.create_error_response("No output generated from the model.") - - # Extract audio from output - # Audio can be in final_output.multimodal_output or final_output.request_output.multimodal_output - # Support both "audio" and "model_outputs" keys for compatibility with different models - audio_output = None - if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: - audio_output = final_output.multimodal_output - if not audio_output and hasattr(final_output, "request_output"): - if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): - audio_output = final_output.request_output.multimodal_output - - # Check for audio data using either "audio" or "model_outputs" key - audio_key = None - if audio_output: - if "audio" in audio_output: - audio_key = "audio" - elif "model_outputs" in audio_output: - audio_key = "model_outputs" - - if not audio_output or audio_key is None: - return self.create_error_response("TTS model did not produce audio output.") - - audio_tensor = audio_output[audio_key] - sample_rate = audio_output.get("sr", 24000) - if hasattr(sample_rate, "item"): - sample_rate = sample_rate.item() - - # Convert tensor to numpy - if hasattr(audio_tensor, "float"): - audio_tensor = audio_tensor.float().detach().cpu().numpy() - - # Squeeze batch dimension if present, but preserve channel dimension for stereo - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.squeeze() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=int(sample_rate), - response_format=request.response_format or "wav", - speed=request.speed or 1.0, - stream_format=request.stream_format, - base64_encode=False, - ) - - audio_response: AudioResponse = self.create_audio(audio_obj) - return Response(content=audio_response.audio_data, media_type=audio_response.media_type) + audio_data, media_type = await self._generate_audio_bytes(request) + return Response(content=audio_data, media_type=media_type) except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm_omni/entrypoints/openai/serving_speech_stream.py b/vllm_omni/entrypoints/openai/serving_speech_stream.py new file mode 100644 index 0000000000..be0c447a5d --- /dev/null +++ b/vllm_omni/entrypoints/openai/serving_speech_stream.py @@ -0,0 +1,242 @@ +"""WebSocket handler for streaming text input TTS. + +Accepts text incrementally via WebSocket, buffers and splits at sentence +boundaries, and generates audio per sentence using the existing TTS pipeline. 
+ +Protocol: + Client -> Server: + {"type": "session.config", ...} # Session config (sent once first) + {"type": "input.text", "text": "..."} # Text chunks + {"type": "input.done"} # End of input + + Server -> Client: + {"type": "audio.start", "sentence_index": 0, "sentence_text": "...", "format": "wav"} + + {"type": "audio.done", "sentence_index": 0} + {"type": "session.done", "total_sentences": N} + {"type": "error", "message": "..."} +""" + +import asyncio +import json + +from fastapi import WebSocket, WebSocketDisconnect +from pydantic import ValidationError +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.protocol.audio import ( + OpenAICreateSpeechRequest, + StreamingSpeechSessionConfig, +) +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter + +logger = init_logger(__name__) + +_DEFAULT_IDLE_TIMEOUT = 30.0 # seconds +_DEFAULT_CONFIG_TIMEOUT = 10.0 # seconds + + +class OmniStreamingSpeechHandler: + """Handles WebSocket sessions for streaming text-input TTS. + + Each WebSocket connection is an independent session. Text arrives + incrementally, is split at sentence boundaries, and audio is generated + per sentence using the existing OmniOpenAIServingSpeech pipeline. + + Args: + speech_service: The existing TTS serving instance (reused for + validation and audio generation). + idle_timeout: Max seconds to wait for a message before closing. + config_timeout: Max seconds to wait for the initial session.config. + """ + + def __init__( + self, + speech_service: OmniOpenAIServingSpeech, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + config_timeout: float = _DEFAULT_CONFIG_TIMEOUT, + ) -> None: + self._speech_service = speech_service + self._idle_timeout = idle_timeout + self._config_timeout = config_timeout + + async def handle_session(self, websocket: WebSocket) -> None: + """Main session loop for a single WebSocket connection.""" + await websocket.accept() + + try: + # 1. Wait for session.config + config = await self._receive_config(websocket) + if config is None: + return # Error already sent, connection closing + + splitter = SentenceSplitter() + sentence_index = 0 + + # 2. 
Receive text chunks until input.done + while True: + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + await self._send_error(websocket, "Idle timeout: no message received") + return + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON message") + continue + + msg_type = msg.get("type") + + if msg_type == "input.text": + text = msg.get("text", "") + sentences = splitter.add_text(text) + for sentence in sentences: + await self._generate_and_send( + websocket, config, sentence, sentence_index + ) + sentence_index += 1 + + elif msg_type == "input.done": + # Flush remaining buffer + remaining = splitter.flush() + if remaining: + await self._generate_and_send( + websocket, config, remaining, sentence_index + ) + sentence_index += 1 + + # Send session.done + await websocket.send_json({ + "type": "session.done", + "total_sentences": sentence_index, + }) + return + + else: + await self._send_error( + websocket, + f"Unknown message type: {msg_type}", + ) + + except WebSocketDisconnect: + logger.info("Streaming speech: client disconnected") + except Exception as e: + logger.exception("Streaming speech session error: %s", e) + try: + await self._send_error(websocket, f"Internal error: {e}") + except Exception: + pass + + async def _receive_config( + self, websocket: WebSocket + ) -> StreamingSpeechSessionConfig | None: + """Wait for and validate the session.config message.""" + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._config_timeout, + ) + except asyncio.TimeoutError: + await self._send_error( + websocket, "Timeout waiting for session.config" + ) + return None + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON in session.config") + return None + + if msg.get("type") != "session.config": + await self._send_error( + websocket, + f"Expected session.config, got: {msg.get('type')}", + ) + return None + + try: + config = StreamingSpeechSessionConfig(**{ + k: v for k, v in msg.items() if k != "type" + }) + except ValidationError as e: + await self._send_error( + websocket, f"Invalid session config: {e}" + ) + return None + + return config + + async def _generate_and_send( + self, + websocket: WebSocket, + config: StreamingSpeechSessionConfig, + sentence_text: str, + sentence_index: int, + ) -> None: + """Generate audio for a single sentence and send it over WebSocket.""" + response_format = config.response_format or "wav" + + # Send audio.start + await websocket.send_json({ + "type": "audio.start", + "sentence_index": sentence_index, + "sentence_text": sentence_text, + "format": response_format, + }) + + try: + # Build a per-sentence request reusing session config + request = OpenAICreateSpeechRequest( + input=sentence_text, + model=config.model, + voice=config.voice, + task_type=config.task_type, + language=config.language, + instructions=config.instructions, + response_format=response_format, + speed=config.speed, + max_new_tokens=config.max_new_tokens, + ref_audio=config.ref_audio, + ref_text=config.ref_text, + x_vector_only_mode=config.x_vector_only_mode, + ) + + audio_bytes, _ = await self._speech_service._generate_audio_bytes( + request + ) + + # Send binary audio frame + await websocket.send_bytes(audio_bytes) + + except Exception as e: + logger.error( + "Generation failed for sentence %d: %s", sentence_index, e + ) + await self._send_error( + websocket, + 
f"Generation failed for sentence {sentence_index}: {e}", + ) + + # Send audio.done (even on error, so client can track progress) + await websocket.send_json({ + "type": "audio.done", + "sentence_index": sentence_index, + }) + + @staticmethod + async def _send_error(websocket: WebSocket, message: str) -> None: + """Send an error message to the client.""" + try: + await websocket.send_json({ + "type": "error", + "message": message, + }) + except Exception: + pass # Connection may already be closed diff --git a/vllm_omni/entrypoints/openai/text_splitter.py b/vllm_omni/entrypoints/openai/text_splitter.py new file mode 100644 index 0000000000..c229a07c2a --- /dev/null +++ b/vllm_omni/entrypoints/openai/text_splitter.py @@ -0,0 +1,91 @@ +"""Multi-language sentence boundary detector for streaming TTS input. + +Buffers incoming text and splits at sentence boundaries (English and CJK), +yielding complete sentences for audio generation. +""" + +import re + + +# Sentence boundary pattern: +# - English: .!? followed by whitespace or end of string +# - CJK fullwidth: 。!?,; +_SENTENCE_BOUNDARY_RE = re.compile( + r"(?<=[.!?])\s+" # English punctuation followed by whitespace + r"|(?<=[。!?,;])" # CJK fullwidth punctuation +) + + +class SentenceSplitter: + """Incremental sentence splitter for streaming text input. + + Buffers text and yields complete sentences when boundaries are detected. + Designed for TTS pipelines where text arrives incrementally (e.g., from STT). + + Args: + min_sentence_length: Minimum character length for a sentence. + Sentences shorter than this are kept in the buffer to avoid + splitting on abbreviations like "Dr." or "U.S.". + """ + + def __init__(self, min_sentence_length: int = 2) -> None: + self._buffer: str = "" + self._min_sentence_length = min_sentence_length + + @property + def buffer(self) -> str: + """Current buffered text.""" + return self._buffer + + def add_text(self, text: str) -> list[str]: + """Add text to the buffer and return any complete sentences. + + Args: + text: Incoming text chunk. + + Returns: + List of complete sentences extracted from the buffer. + May be empty if no sentence boundary was found. + """ + if not text: + return [] + + self._buffer += text + return self._extract_sentences() + + def flush(self) -> str | None: + """Flush remaining buffered text as a final sentence. + + Returns: + The remaining buffered text (stripped), or None if buffer is empty. + """ + remaining = self._buffer.strip() + self._buffer = "" + return remaining if remaining else None + + def _extract_sentences(self) -> list[str]: + """Split buffer at sentence boundaries, keeping incomplete text buffered.""" + parts = _SENTENCE_BOUNDARY_RE.split(self._buffer) + + if len(parts) <= 1: + # No boundary found — keep everything in buffer + return [] + + sentences: list[str] = [] + carry = "" + # All parts except the last are complete sentences + for i in range(len(parts) - 1): + text = carry + parts[i] + carry = "" + stripped = text.strip() + if len(stripped) >= self._min_sentence_length: + sentences.append(stripped) + elif stripped: + # Too short (e.g. 
"Dr.") — carry forward to next part + carry = text + # else: empty, skip + + # Last part stays in buffer (may be incomplete) + self._buffer = carry + parts[-1] + + return sentences From c2abb5edaa9fef63545ae19b5dcf6c9e1171488e Mon Sep 17 00:00:00 2001 From: lishunyang Date: Fri, 6 Feb 2026 03:33:31 +0800 Subject: [PATCH 2/2] fix precommit Signed-off-by: lishunyang --- PR_description.md | 97 --------- .../qwen3_tts/streaming_speech_client.py | 65 +++--- .../openai_api/test_serving_speech_stream.py | 191 ++++++++++-------- .../openai_api/test_text_splitter.py | 2 - vllm_omni/entrypoints/openai/api_server.py | 10 +- .../openai/serving_speech_stream.py | 76 ++++--- vllm_omni/entrypoints/openai/text_splitter.py | 1 - 7 files changed, 179 insertions(+), 263 deletions(-) delete mode 100644 PR_description.md diff --git a/PR_description.md b/PR_description.md deleted file mode 100644 index 027d82f157..0000000000 --- a/PR_description.md +++ /dev/null @@ -1,97 +0,0 @@ -# [Feature][TTS] Streaming Text Input for Qwen3-TTS via WebSocket - -## Summary - -Add a WebSocket endpoint `/v1/audio/speech/stream` that accepts text input incrementally (e.g., from a real-time STT pipeline), buffers and splits at sentence boundaries, and generates audio per sentence using the existing TTS pipeline. - -This enables real-time text-to-speech workflows where text is produced progressively (speech-to-text, LLM token streaming, live captions) and audio needs to be generated as soon as complete sentences are available, rather than waiting for the entire input. - -**Scope:** Streaming text *input* only. Each sentence produces a complete audio response. Streaming audio *output* (chunked PCM) is tracked separately in PR #1189. - -## Motivation - -The current `/v1/audio/speech` REST endpoint requires the full text upfront. In real-time pipelines (e.g., STT → LLM → TTS), text arrives incrementally. Without streaming input support, clients must either: -1. Wait for the entire text before calling TTS (high latency), or -2. Manually implement sentence buffering and make multiple REST calls (complex, no session state). - -This PR solves both issues with a single WebSocket session that handles buffering, sentence detection, and per-sentence generation automatically. - -## WebSocket Protocol - -**Transport:** WebSocket (industry standard — used by OpenAI Realtime API, ElevenLabs, Azure TTS) - -### Client → Server - -```jsonc -// 1. Session config (sent once, first message) -{"type": "session.config", "voice": "Vivian", "task_type": "CustomVoice", "language": "Auto"} - -// 2. Text chunks (sent incrementally, any number of times) -{"type": "input.text", "text": "Hello, how are you? "} - -// 3. End of input (flushes remaining buffer) -{"type": "input.done"} -``` - -### Server → Client - -```jsonc -// Per-sentence: metadata → binary audio → completion -{"type": "audio.start", "sentence_index": 0, "sentence_text": "Hello, how are you?", "format": "wav"} - -{"type": "audio.done", "sentence_index": 0} - -// Session complete -{"type": "session.done", "total_sentences": 3} - -// Non-fatal error (session continues) -{"type": "error", "message": "..."} -``` - -## Changes - -### New Files - -| File | Description | -|------|-------------| -| `vllm_omni/entrypoints/openai/text_splitter.py` | `SentenceSplitter` — incremental sentence boundary detector. Regex-based splitting at English `.!?` + whitespace and CJK fullwidth `。!?,;`. Configurable `min_sentence_length` (default 2, CJK-friendly). 
| -| `vllm_omni/entrypoints/openai/serving_speech_stream.py` | `OmniStreamingSpeechHandler` — WebSocket session handler. Manages config validation, idle/config timeouts (30s/10s), per-sentence audio generation, and error resilience (one sentence failure doesn't kill the session). | -| `examples/online_serving/qwen3_tts/streaming_speech_client.py` | Python WebSocket client example. Supports `--simulate-stt` mode (word-by-word with configurable delay), all 3 task types (CustomVoice, VoiceDesign, Base), saves per-sentence audio files. | -| `tests/entrypoints/openai_api/test_text_splitter.py` | Unit tests for `SentenceSplitter`: English/Chinese/mixed splitting, incremental accumulation, flush behavior, edge cases. | -| `tests/entrypoints/openai_api/test_serving_speech_stream.py` | WebSocket integration tests: session lifecycle, multi-sentence, incremental text, flush-on-done, empty input, invalid config, invalid JSON, unknown message types, generation failure recovery. | - -### Modified Files - -| File | Description | -|------|-------------| -| `vllm_omni/entrypoints/openai/serving_speech.py` | **Refactor:** Extracted `_generate_audio_bytes(request) → (bytes, media_type)` from `create_speech()`. The REST endpoint delegates to it; the WebSocket handler reuses it per sentence. No behavior change for existing callers. | -| `vllm_omni/entrypoints/openai/protocol/audio.py` | Added `StreamingSpeechSessionConfig` Pydantic model for WebSocket session configuration (mirrors `OpenAICreateSpeechRequest` fields minus `input`). | -| `vllm_omni/entrypoints/openai/api_server.py` | Added `@router.websocket("/v1/audio/speech/stream")` route and `OmniStreamingSpeechHandler` initialization in `omni_init_app_state()`. | -| `examples/online_serving/qwen3_tts/README.md` | Added streaming text input documentation section with protocol spec, parameters, and usage examples. | - -## Design Decisions - -| Decision | Rationale | -|----------|-----------| -| **WebSocket** (not SSE/HTTP chunked) | Bidirectional: client sends text incrementally, server sends audio back. Industry standard for real-time TTS (OpenAI, ElevenLabs, Azure). | -| **Sentence-level chunking** (not token/word) | Natural speech boundaries produce coherent audio. Avoids artifacts from splitting mid-sentence. Production-ready granularity. | -| **`min_sentence_length=2`** | Prevents splitting on lone punctuation (`.`) while supporting short CJK sentences like `你好!` (3 chars). | -| **`_generate_audio_bytes()` extraction** | Clean separation of concerns. REST endpoint wraps in `Response`; WebSocket sends raw bytes. No code duplication. | -| **Per-sentence error resilience** | If generation fails for one sentence, an error is sent but the session continues for remaining sentences. | -| **Idle + config timeouts** | Prevents resource leaks from abandoned connections (30s idle, 10s for initial config). | - -## Test Plan - -- [ ] `pytest tests/entrypoints/openai_api/test_text_splitter.py` — sentence splitter unit tests -- [ ] `pytest tests/entrypoints/openai_api/test_serving_speech_stream.py` — WebSocket integration tests -- [ ] `pytest tests/entrypoints/openai_api/test_serving_speech.py` — existing REST endpoint tests (verify refactor is non-breaking) -- [ ] Manual test with live server: - ```bash - # Start server - ./examples/online_serving/qwen3_tts/run_server.sh CustomVoice - - # Run streaming client - python examples/online_serving/qwen3_tts/streaming_speech_client.py \ - --text "Hello world. How are you? I am fine." 
\ - --simulate-stt - ``` diff --git a/examples/online_serving/qwen3_tts/streaming_speech_client.py b/examples/online_serving/qwen3_tts/streaming_speech_client.py index d595d3d105..894651eb18 100644 --- a/examples/online_serving/qwen3_tts/streaming_speech_client.py +++ b/examples/online_serving/qwen3_tts/streaming_speech_client.py @@ -64,17 +64,25 @@ async def send_text(): words = text.split(" ") for i, word in enumerate(words): chunk = word + (" " if i < len(words) - 1 else "") - await ws.send(json.dumps({ - "type": "input.text", - "text": chunk, - })) + await ws.send( + json.dumps( + { + "type": "input.text", + "text": chunk, + } + ) + ) print(f" Sent: {chunk!r}") await asyncio.sleep(stt_delay) else: - await ws.send(json.dumps({ - "type": "input.text", - "text": text, - })) + await ws.send( + json.dumps( + { + "type": "input.text", + "text": text, + } + ) + ) print(f"Sent full text: {text!r}") # 3. Signal end of input @@ -107,19 +115,11 @@ async def send_text(): msg_type = msg.get("type") if msg_type == "audio.start": - print( - f" [sentence {msg['sentence_index']}] " - f"Generating: {msg['sentence_text']!r}" - ) + print(f" [sentence {msg['sentence_index']}] Generating: {msg['sentence_text']!r}") elif msg_type == "audio.done": - print( - f" [sentence {msg['sentence_index']}] Done" - ) + print(f" [sentence {msg['sentence_index']}] Done") elif msg_type == "session.done": - print( - f"\nSession complete: {msg['total_sentences']} " - f"sentence(s) generated" - ) + print(f"\nSession complete: {msg['total_sentences']} sentence(s) generated") break elif msg_type == "error": print(f" ERROR: {msg['message']}") @@ -136,9 +136,7 @@ async def send_text(): def main(): - parser = argparse.ArgumentParser( - description="Streaming text-input TTS client" - ) + parser = argparse.ArgumentParser(description="Streaming text-input TTS client") parser.add_argument( "--url", default="ws://localhost:8000/v1/audio/speech/stream", @@ -165,21 +163,15 @@ def main(): help="TTS task type", ) parser.add_argument("--language", default="Auto", help="Language") - parser.add_argument( - "--instructions", default=None, help="Voice style instructions" - ) + parser.add_argument("--instructions", default=None, help="Voice style instructions") parser.add_argument( "--response-format", default="wav", choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], help="Audio format", ) - parser.add_argument( - "--speed", type=float, default=1.0, help="Playback speed (0.25-4.0)" - ) - parser.add_argument( - "--max-new-tokens", type=int, default=None, help="Max tokens" - ) + parser.add_argument("--speed", type=float, default=1.0, help="Playback speed (0.25-4.0)") + parser.add_argument("--max-new-tokens", type=int, default=None, help="Max tokens") # Base task options parser.add_argument("--ref-audio", default=None, help="Reference audio") @@ -209,8 +201,15 @@ def main(): # Build session config (only include non-None values) config = {} for key in [ - "model", "voice", "task_type", "language", "instructions", - "response_format", "speed", "max_new_tokens", "ref_audio", + "model", + "voice", + "task_type", + "language", + "instructions", + "response_format", + "speed", + "max_new_tokens", + "ref_audio", "ref_text", ]: val = getattr(args, key.replace("-", "_"), None) diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py index edf8b16f7b..eedb54fa28 100644 --- a/tests/entrypoints/openai_api/test_serving_speech_stream.py +++ 
b/tests/entrypoints/openai_api/test_serving_speech_stream.py @@ -1,24 +1,17 @@ """Integration tests for the streaming speech WebSocket endpoint.""" -import json -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock -import numpy as np -import pytest -import torch from fastapi import FastAPI from starlette.testclient import TestClient -from starlette.websockets import WebSocketDisconnect from vllm_omni.entrypoints.openai.protocol.audio import ( - AudioResponse, StreamingSpeechSessionConfig, ) from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_speech_stream import ( OmniStreamingSpeechHandler, ) -from vllm_omni.outputs import OmniRequestOutput def _create_mock_speech_service(): @@ -68,17 +61,21 @@ def test_single_sentence(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: # Send config - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - "response_format": "wav", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "response_format": "wav", + } + ) # Send text with sentence boundary - ws.send_json({ - "type": "input.text", - "text": "Hello world. ", - }) + ws.send_json( + { + "type": "input.text", + "text": "Hello world. ", + } + ) # Receive audio.start msg = ws.receive_json() @@ -109,15 +106,19 @@ def test_multiple_sentences(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) - - ws.send_json({ - "type": "input.text", - "text": "Hello world. How are you? ", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + ws.send_json( + { + "type": "input.text", + "text": "Hello world. How are you? 
", + } + ) # First sentence msg = ws.receive_json() @@ -147,10 +148,12 @@ def test_incremental_text(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) # Send text incrementally ws.send_json({"type": "input.text", "text": "Hello "}) @@ -181,15 +184,19 @@ def test_flush_on_done(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) - - ws.send_json({ - "type": "input.text", - "text": "Hello world without punctuation", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + ws.send_json( + { + "type": "input.text", + "text": "Hello world without punctuation", + } + ) # No sentence boundary, so nothing generated yet # Now send done to flush @@ -213,10 +220,12 @@ def test_empty_flush(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) # Send nothing, just done ws.send_json({"type": "input.done"}) @@ -236,10 +245,12 @@ def test_missing_config(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: # Send text instead of config - ws.send_json({ - "type": "input.text", - "text": "Hello", - }) + ws.send_json( + { + "type": "input.text", + "text": "Hello", + } + ) # Should get error about expecting session.config msg = ws.receive_json() @@ -252,10 +263,12 @@ def test_invalid_json(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) # Send invalid JSON ws.send_text("not json at all") @@ -276,10 +289,12 @@ def test_unknown_message_type(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) ws.send_json({"type": "unknown.type"}) @@ -304,24 +319,26 @@ async def mock_generate_with_failure(request): raise RuntimeError("Generation failed") return b"RIFF" + b"\x00" * 100, "audio/wav" - service._generate_audio_bytes = AsyncMock( - side_effect=mock_generate_with_failure - ) + service._generate_audio_bytes = AsyncMock(side_effect=mock_generate_with_failure) app, _, _ = _build_test_app(speech_service=service) with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) # First sentence will fail - ws.send_json({ - "type": "input.text", - "text": "First sentence. Second sentence. ", - }) + ws.send_json( + { + "type": "input.text", + "text": "First sentence. Second sentence. 
", + } + ) # First sentence: audio.start -> error -> audio.done msg = ws.receive_json() @@ -358,16 +375,18 @@ def test_valid_config_all_fields(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "voice": "Vivian", - "task_type": "CustomVoice", - "language": "English", - "instructions": "Speak cheerfully", - "response_format": "mp3", - "speed": 1.5, - "max_new_tokens": 1024, - }) + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "task_type": "CustomVoice", + "language": "English", + "instructions": "Speak cheerfully", + "response_format": "mp3", + "speed": 1.5, + "max_new_tokens": 1024, + } + ) ws.send_json({"type": "input.done"}) msg = ws.receive_json() @@ -378,10 +397,12 @@ def test_invalid_speed(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "speed": 10.0, # invalid: > 4.0 - }) + ws.send_json( + { + "type": "session.config", + "speed": 10.0, # invalid: > 4.0 + } + ) msg = ws.receive_json() assert msg["type"] == "error" @@ -392,10 +413,12 @@ def test_invalid_response_format(self): with TestClient(app) as client: with client.websocket_connect("/v1/audio/speech/stream") as ws: - ws.send_json({ - "type": "session.config", - "response_format": "invalid", - }) + ws.send_json( + { + "type": "session.config", + "response_format": "invalid", + } + ) msg = ws.receive_json() assert msg["type"] == "error" diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py index 988172e009..a4e303b3b8 100644 --- a/tests/entrypoints/openai_api/test_text_splitter.py +++ b/tests/entrypoints/openai_api/test_text_splitter.py @@ -1,7 +1,5 @@ """Tests for SentenceSplitter used in streaming TTS input.""" -import pytest - from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index f95b0542cc..2f8ccf95c2 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -834,10 +834,12 @@ async def streaming_speech(websocket: WebSocket): handler = getattr(websocket.app.state, "openai_streaming_speech", None) if handler is None: await websocket.accept() - await websocket.send_json({ - "type": "error", - "message": "Streaming speech is not available", - }) + await websocket.send_json( + { + "type": "error", + "message": "Streaming speech is not available", + } + ) await websocket.close() return await handler.handle_session(websocket) diff --git a/vllm_omni/entrypoints/openai/serving_speech_stream.py b/vllm_omni/entrypoints/openai/serving_speech_stream.py index be0c447a5d..708ea6d5ac 100644 --- a/vllm_omni/entrypoints/openai/serving_speech_stream.py +++ b/vllm_omni/entrypoints/openai/serving_speech_stream.py @@ -97,25 +97,23 @@ async def handle_session(self, websocket: WebSocket) -> None: text = msg.get("text", "") sentences = splitter.add_text(text) for sentence in sentences: - await self._generate_and_send( - websocket, config, sentence, sentence_index - ) + await self._generate_and_send(websocket, config, sentence, sentence_index) sentence_index += 1 elif msg_type == "input.done": # Flush remaining buffer remaining = splitter.flush() if remaining: - await self._generate_and_send( - websocket, config, remaining, sentence_index - ) + await self._generate_and_send(websocket, config, 
remaining, sentence_index) sentence_index += 1 # Send session.done - await websocket.send_json({ - "type": "session.done", - "total_sentences": sentence_index, - }) + await websocket.send_json( + { + "type": "session.done", + "total_sentences": sentence_index, + } + ) return else: @@ -133,9 +131,7 @@ async def handle_session(self, websocket: WebSocket) -> None: except Exception: pass - async def _receive_config( - self, websocket: WebSocket - ) -> StreamingSpeechSessionConfig | None: + async def _receive_config(self, websocket: WebSocket) -> StreamingSpeechSessionConfig | None: """Wait for and validate the session.config message.""" try: raw = await asyncio.wait_for( @@ -143,9 +139,7 @@ async def _receive_config( timeout=self._config_timeout, ) except asyncio.TimeoutError: - await self._send_error( - websocket, "Timeout waiting for session.config" - ) + await self._send_error(websocket, "Timeout waiting for session.config") return None try: @@ -162,13 +156,9 @@ async def _receive_config( return None try: - config = StreamingSpeechSessionConfig(**{ - k: v for k, v in msg.items() if k != "type" - }) + config = StreamingSpeechSessionConfig(**{k: v for k, v in msg.items() if k != "type"}) except ValidationError as e: - await self._send_error( - websocket, f"Invalid session config: {e}" - ) + await self._send_error(websocket, f"Invalid session config: {e}") return None return config @@ -184,12 +174,14 @@ async def _generate_and_send( response_format = config.response_format or "wav" # Send audio.start - await websocket.send_json({ - "type": "audio.start", - "sentence_index": sentence_index, - "sentence_text": sentence_text, - "format": response_format, - }) + await websocket.send_json( + { + "type": "audio.start", + "sentence_index": sentence_index, + "sentence_text": sentence_text, + "format": response_format, + } + ) try: # Build a per-sentence request reusing session config @@ -208,35 +200,35 @@ async def _generate_and_send( x_vector_only_mode=config.x_vector_only_mode, ) - audio_bytes, _ = await self._speech_service._generate_audio_bytes( - request - ) + audio_bytes, _ = await self._speech_service._generate_audio_bytes(request) # Send binary audio frame await websocket.send_bytes(audio_bytes) except Exception as e: - logger.error( - "Generation failed for sentence %d: %s", sentence_index, e - ) + logger.error("Generation failed for sentence %d: %s", sentence_index, e) await self._send_error( websocket, f"Generation failed for sentence {sentence_index}: {e}", ) # Send audio.done (even on error, so client can track progress) - await websocket.send_json({ - "type": "audio.done", - "sentence_index": sentence_index, - }) + await websocket.send_json( + { + "type": "audio.done", + "sentence_index": sentence_index, + } + ) @staticmethod async def _send_error(websocket: WebSocket, message: str) -> None: """Send an error message to the client.""" try: - await websocket.send_json({ - "type": "error", - "message": message, - }) + await websocket.send_json( + { + "type": "error", + "message": message, + } + ) except Exception: pass # Connection may already be closed diff --git a/vllm_omni/entrypoints/openai/text_splitter.py b/vllm_omni/entrypoints/openai/text_splitter.py index c229a07c2a..27d30f417e 100644 --- a/vllm_omni/entrypoints/openai/text_splitter.py +++ b/vllm_omni/entrypoints/openai/text_splitter.py @@ -6,7 +6,6 @@ import re - # Sentence boundary pattern: # - English: .!? followed by whitespace or end of string # - CJK fullwidth: 。!?,;
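
---

For reference, here is a minimal usage sketch of the `SentenceSplitter` introduced by this patch, showing the incremental `add_text()`/`flush()` contract that the WebSocket handler relies on. The example text, chunk sizes, and printed results are illustrative only and assume the default `min_sentence_length=2`; they are not taken from the included tests.

```python
# Minimal sketch of incremental sentence splitting (illustrative input/output).
from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter

splitter = SentenceSplitter(min_sentence_length=2)

# Text arrives in chunks, e.g. word-by-word from an STT pipeline.
print(splitter.add_text("Hello world. How "))  # -> ["Hello world."]

# English "?" + whitespace and CJK fullwidth "。" both count as boundaries.
print(splitter.add_text("are you? 谢谢你。"))    # -> ["How are you?", "谢谢你。"]

# No boundary yet, so the fragment stays buffered and nothing is emitted.
print(splitter.add_text("Trailing fragment"))  # -> []

# An "input.done" message triggers a flush of whatever remains in the buffer.
print(splitter.flush())                        # -> "Trailing fragment"
```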