diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md index 1c9bd48203..bb6366ed5b 100644 --- a/examples/online_serving/qwen3_tts/README.md +++ b/examples/online_serving/qwen3_tts/README.md @@ -162,9 +162,90 @@ with open("output.wav", "wb") as f: f.write(response.content) ``` +## Streaming Text Input + +The `/v1/audio/speech/stream` WebSocket endpoint accepts text incrementally (e.g., from a real-time STT pipeline), buffers and splits at sentence boundaries, and generates audio per sentence. + +> **Note:** This is streaming *text input* only. Each sentence produces a complete audio response. For streaming audio output, see PR #1189. + +### Quick Start + +```bash +# Send full text (sentences are auto-detected) +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." + +# Simulate STT: send text word-by-word +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 + +# VoiceDesign task +python streaming_speech_client.py \ + --text "Today is a great day. The weather is nice." \ + --task-type VoiceDesign \ + --instructions "A cheerful young female voice" +``` + +### WebSocket Protocol + +**Client -> Server:** + +```jsonc +// 1. Session config (sent once first) +{"type": "session.config", "voice": "Vivian", "task_type": "CustomVoice", "language": "Auto"} + +// 2. Text chunks (sent incrementally) +{"type": "input.text", "text": "Hello, how are you? "} + +// 3. End of input (flushes remaining buffer) +{"type": "input.done"} +``` + +**Server -> Client:** + +```jsonc +// Audio metadata (before binary frame) +{"type": "audio.start", "sentence_index": 0, "sentence_text": "Hello, how are you?", "format": "wav"} + +// Binary WebSocket frame: raw audio bytes + +// Per-sentence completion +{"type": "audio.done", "sentence_index": 0} + +// Session complete +{"type": "session.done", "total_sentences": 3} + +// Error (non-fatal, session continues) +{"type": "error", "message": "..."} +``` + +### Session Config Parameters + +All parameters from the REST API are supported: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `voice` | string | None | Speaker voice name | +| `task_type` | string | None | CustomVoice, VoiceDesign, or Base | +| `language` | string | None | Language code | +| `instructions` | string | None | Voice style instructions | +| `response_format` | string | "wav" | Audio format per sentence | +| `speed` | float | 1.0 | Playback speed (0.25-4.0) | +| `max_new_tokens` | int | None | Max tokens per sentence | +| `ref_audio` | string | None | Reference audio (Base task) | +| `ref_text` | string | None | Reference text (Base task) | + +### Sentence Detection + +Text is automatically split at sentence boundaries: +- **English:** `.` `!` `?` followed by whitespace +- **CJK:** fullwidth punctuation `。` `!` `?` `,` `;` + +If text never forms a complete sentence, it is flushed when `input.done` is sent. + ## Limitations -- **No streaming**: Audio is generated completely before being returned. Streaming will be supported after the pipeline is disaggregated (see RFC #938). - **Single request**: Batch processing is not yet optimized for online serving. 
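A minimal client sketch for the streaming protocol described above (an editor's illustration, assuming the third-party `websockets` package; `streaming_speech_client.py` in this PR is the full example):

```python
import asyncio
import json

import websockets  # pip install websockets


async def speak(text: str) -> None:
    uri = "ws://localhost:8000/v1/audio/speech/stream"
    async with websockets.connect(uri) as ws:
        # 1. session config, 2. text, 3. end of input
        await ws.send(json.dumps({"type": "session.config", "voice": "Vivian"}))
        await ws.send(json.dumps({"type": "input.text", "text": text}))
        await ws.send(json.dumps({"type": "input.done"}))

        index = 0
        while True:
            frame = await ws.recv()
            if isinstance(frame, bytes):
                # One binary audio frame per sentence (preceded by an audio.start JSON frame)
                with open(f"sentence_{index:03d}.wav", "wb") as f:
                    f.write(frame)
                index += 1
            elif json.loads(frame).get("type") == "session.done":
                break


asyncio.run(speak("Hello world. How are you?"))
```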
## Troubleshooting diff --git a/examples/online_serving/qwen3_tts/streaming_speech_client.py b/examples/online_serving/qwen3_tts/streaming_speech_client.py new file mode 100644 index 0000000000..894651eb18 --- /dev/null +++ b/examples/online_serving/qwen3_tts/streaming_speech_client.py @@ -0,0 +1,234 @@ +"""WebSocket client for streaming text-input TTS. + +Connects to the /v1/audio/speech/stream endpoint, sends text incrementally +(simulating real-time STT output), and saves per-sentence audio files. + +Usage: + # Send full text at once + python streaming_speech_client.py --text "Hello world. How are you? I am fine." + + # Simulate STT: send text word-by-word with delay + python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 + + # VoiceDesign task + python streaming_speech_client.py \ + --text "Today is a great day. The weather is nice." \ + --task-type VoiceDesign \ + --instructions "A cheerful young female voice" + + # Base task (voice cloning) + python streaming_speech_client.py \ + --text "Hello world. How are you?" \ + --task-type Base \ + --ref-audio /path/to/reference.wav \ + --ref-text "Transcript of reference audio" + +Requirements: + pip install websockets +""" + +import argparse +import asyncio +import json +import os + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + raise SystemExit(1) + + +async def stream_tts( + url: str, + text: str, + config: dict, + output_dir: str, + simulate_stt: bool = False, + stt_delay: float = 0.1, +) -> None: + """Connect to the streaming TTS endpoint and process audio responses.""" + os.makedirs(output_dir, exist_ok=True) + + async with websockets.connect(url) as ws: + # 1. Send session config + config_msg = {"type": "session.config", **config} + await ws.send(json.dumps(config_msg)) + print(f"Sent session config: {config}") + + # 2. Send text (either all at once or word-by-word) + async def send_text(): + if simulate_stt: + words = text.split(" ") + for i, word in enumerate(words): + chunk = word + (" " if i < len(words) - 1 else "") + await ws.send( + json.dumps( + { + "type": "input.text", + "text": chunk, + } + ) + ) + print(f" Sent: {chunk!r}") + await asyncio.sleep(stt_delay) + else: + await ws.send( + json.dumps( + { + "type": "input.text", + "text": text, + } + ) + ) + print(f"Sent full text: {text!r}") + + # 3. 
Signal end of input + await ws.send(json.dumps({"type": "input.done"})) + print("Sent input.done") + + # Run sender and receiver concurrently + sender_task = asyncio.create_task(send_text()) + + response_format = config.get("response_format", "wav") + sentence_count = 0 + + try: + while True: + message = await ws.recv() + + if isinstance(message, bytes): + # Binary frame: audio data + filename = os.path.join( + output_dir, + f"sentence_{sentence_count:03d}.{response_format}", + ) + with open(filename, "wb") as f: + f.write(message) + print(f" Saved audio: {filename} ({len(message)} bytes)") + sentence_count += 1 + else: + # JSON frame + msg = json.loads(message) + msg_type = msg.get("type") + + if msg_type == "audio.start": + print(f" [sentence {msg['sentence_index']}] Generating: {msg['sentence_text']!r}") + elif msg_type == "audio.done": + print(f" [sentence {msg['sentence_index']}] Done") + elif msg_type == "session.done": + print(f"\nSession complete: {msg['total_sentences']} sentence(s) generated") + break + elif msg_type == "error": + print(f" ERROR: {msg['message']}") + else: + print(f" Unknown message: {msg}") + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + + print(f"\nAudio files saved to: {output_dir}/") + + +def main(): + parser = argparse.ArgumentParser(description="Streaming text-input TTS client") + parser.add_argument( + "--url", + default="ws://localhost:8000/v1/audio/speech/stream", + help="WebSocket endpoint URL", + ) + parser.add_argument( + "--text", + required=True, + help="Text to synthesize", + ) + parser.add_argument( + "--output-dir", + default="streaming_tts_output", + help="Directory to save audio files (default: streaming_tts_output)", + ) + + # Session config options + parser.add_argument("--model", default=None, help="Model name") + parser.add_argument("--voice", default="Vivian", help="Speaker voice") + parser.add_argument( + "--task-type", + default="CustomVoice", + choices=["CustomVoice", "VoiceDesign", "Base"], + help="TTS task type", + ) + parser.add_argument("--language", default="Auto", help="Language") + parser.add_argument("--instructions", default=None, help="Voice style instructions") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], + help="Audio format", + ) + parser.add_argument("--speed", type=float, default=1.0, help="Playback speed (0.25-4.0)") + parser.add_argument("--max-new-tokens", type=int, default=None, help="Max tokens") + + # Base task options + parser.add_argument("--ref-audio", default=None, help="Reference audio") + parser.add_argument("--ref-text", default=None, help="Reference text") + parser.add_argument( + "--x-vector-only-mode", + action="store_true", + default=False, + help="Speaker embedding only mode", + ) + + # STT simulation + parser.add_argument( + "--simulate-stt", + action="store_true", + help="Simulate STT by sending text word-by-word", + ) + parser.add_argument( + "--stt-delay", + type=float, + default=0.1, + help="Delay between words in STT simulation (seconds)", + ) + + args = parser.parse_args() + + # Build session config (only include non-None values) + config = {} + for key in [ + "model", + "voice", + "task_type", + "language", + "instructions", + "response_format", + "speed", + "max_new_tokens", + "ref_audio", + "ref_text", + ]: + val = getattr(args, key.replace("-", "_"), None) + if val is not None: + config[key] = val + if args.x_vector_only_mode: + config["x_vector_only_mode"] = True + 
+ asyncio.run( + stream_tts( + url=args.url, + text=args.text, + config=config, + output_dir=args.output_dir, + simulate_stt=args.simulate_stt, + stt_delay=args.stt_delay, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py new file mode 100644 index 0000000000..eedb54fa28 --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py @@ -0,0 +1,453 @@ +"""Integration tests for the streaming speech WebSocket endpoint.""" + +from unittest.mock import AsyncMock, MagicMock + +from fastapi import FastAPI +from starlette.testclient import TestClient + +from vllm_omni.entrypoints.openai.protocol.audio import ( + StreamingSpeechSessionConfig, +) +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import ( + OmniStreamingSpeechHandler, +) + + +def _create_mock_speech_service(): + """Create a mock OmniOpenAIServingSpeech for testing.""" + service = MagicMock(spec=OmniOpenAIServingSpeech) + + # Mock _generate_audio_bytes to return fake WAV data + async def mock_generate(request): + # Return minimal WAV-like bytes and media type + fake_audio = b"RIFF" + b"\x00" * 100 # Fake WAV header + data + return fake_audio, "audio/wav" + + service._generate_audio_bytes = AsyncMock(side_effect=mock_generate) + return service + + +def _build_test_app( + speech_service=None, + idle_timeout=5.0, + config_timeout=5.0, +): + """Build a FastAPI app with the streaming speech WebSocket route.""" + if speech_service is None: + speech_service = _create_mock_speech_service() + + handler = OmniStreamingSpeechHandler( + speech_service=speech_service, + idle_timeout=idle_timeout, + config_timeout=config_timeout, + ) + + app = FastAPI() + + @app.websocket("/v1/audio/speech/stream") + async def ws_endpoint(websocket): + await handler.handle_session(websocket) + + return app, handler, speech_service + + +class TestStreamingSpeechBasicLifecycle: + """Test basic session lifecycle: config -> text -> done.""" + + def test_single_sentence(self): + app, _, service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + # Send config + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "response_format": "wav", + } + ) + + # Send text with sentence boundary + ws.send_json( + { + "type": "input.text", + "text": "Hello world. ", + } + ) + + # Receive audio.start + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + assert msg["sentence_text"] == "Hello world." + assert msg["format"] == "wav" + + # Receive binary audio + audio_data = ws.receive_bytes() + assert len(audio_data) > 0 + + # Receive audio.done + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 0 + + # Send done + ws.send_json({"type": "input.done"}) + + # Receive session.done + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + def test_multiple_sentences(self): + app, _, service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + ws.send_json( + { + "type": "input.text", + "text": "Hello world. How are you? 
", + } + ) + + # First sentence + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + ws.receive_bytes() # audio + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + # Second sentence + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 1 + ws.receive_bytes() # audio + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 2 + + def test_incremental_text(self): + """Text arrives word-by-word, sentence formed across chunks.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + # Send text incrementally + ws.send_json({"type": "input.text", "text": "Hello "}) + ws.send_json({"type": "input.text", "text": "world. "}) + + # Now a sentence boundary was hit + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_text"] == "Hello world." + ws.receive_bytes() + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + # Send done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + +class TestStreamingSpeechFlush: + """Test flush behavior when text has no sentence boundary.""" + + def test_flush_on_done(self): + """Text without sentence boundary is flushed on input.done.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + ws.send_json( + { + "type": "input.text", + "text": "Hello world without punctuation", + } + ) + + # No sentence boundary, so nothing generated yet + # Now send done to flush + ws.send_json({"type": "input.done"}) + + # Should get audio for the flushed text + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert "Hello world without punctuation" in msg["sentence_text"] + ws.receive_bytes() + msg = ws.receive_json() + assert msg["type"] == "audio.done" + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 1 + + def test_empty_flush(self): + """input.done with empty buffer produces no audio.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + # Send nothing, just done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + assert msg["total_sentences"] == 0 + + +class TestStreamingSpeechErrors: + """Test error handling scenarios.""" + + def test_missing_config(self): + """Sending text before config should produce an error.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + # Send text instead of config + ws.send_json( + { + "type": "input.text", + "text": "Hello", + } + ) + + # Should get error about expecting session.config + msg = ws.receive_json() + assert msg["type"] == "error" + assert "session.config" in msg["message"] + + def test_invalid_json(self): + """Invalid JSON should produce an error but not kill 
session.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + # Send invalid JSON + ws.send_text("not json at all") + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Invalid JSON" in msg["message"] + + # Session should still be alive — send done + ws.send_json({"type": "input.done"}) + + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_unknown_message_type(self): + """Unknown message type produces error but doesn't kill session.""" + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + ws.send_json({"type": "unknown.type"}) + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Unknown message type" in msg["message"] + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_generation_failure_continues_session(self): + """If generation fails for one sentence, session continues.""" + service = _create_mock_speech_service() + + call_count = 0 + + async def mock_generate_with_failure(request): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Generation failed") + return b"RIFF" + b"\x00" * 100, "audio/wav" + + service._generate_audio_bytes = AsyncMock(side_effect=mock_generate_with_failure) + + app, _, _ = _build_test_app(speech_service=service) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + } + ) + + # First sentence will fail + ws.send_json( + { + "type": "input.text", + "text": "First sentence. Second sentence. 
", + } + ) + + # First sentence: audio.start -> error -> audio.done + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 0 + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Generation failed" in msg["message"] + + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 0 + + # Second sentence should succeed + msg = ws.receive_json() + assert msg["type"] == "audio.start" + assert msg["sentence_index"] == 1 + ws.receive_bytes() # audio data + msg = ws.receive_json() + assert msg["type"] == "audio.done" + assert msg["sentence_index"] == 1 + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + +class TestStreamingSpeechSessionConfig: + """Test session config validation.""" + + def test_valid_config_all_fields(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "task_type": "CustomVoice", + "language": "English", + "instructions": "Speak cheerfully", + "response_format": "mp3", + "speed": 1.5, + "max_new_tokens": 1024, + } + ) + + ws.send_json({"type": "input.done"}) + msg = ws.receive_json() + assert msg["type"] == "session.done" + + def test_invalid_speed(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "speed": 10.0, # invalid: > 4.0 + } + ) + + msg = ws.receive_json() + assert msg["type"] == "error" + assert "Invalid session config" in msg["message"] + + def test_invalid_response_format(self): + app, _, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "response_format": "invalid", + } + ) + + msg = ws.receive_json() + assert msg["type"] == "error" + + +class TestStreamingSpeechConfigModel: + """Unit tests for StreamingSpeechSessionConfig model.""" + + def test_defaults(self): + config = StreamingSpeechSessionConfig() + assert config.response_format == "wav" + assert config.speed == 1.0 + assert config.voice is None + assert config.task_type is None + + def test_all_fields(self): + config = StreamingSpeechSessionConfig( + model="test", + voice="Vivian", + task_type="CustomVoice", + language="English", + instructions="test", + response_format="mp3", + speed=2.0, + max_new_tokens=1024, + ref_audio="http://example.com/audio.wav", + ref_text="hello", + x_vector_only_mode=True, + ) + assert config.voice == "Vivian" + assert config.speed == 2.0 + assert config.response_format == "mp3" diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py new file mode 100644 index 0000000000..a4e303b3b8 --- /dev/null +++ b/tests/entrypoints/openai_api/test_text_splitter.py @@ -0,0 +1,227 @@ +"""Tests for SentenceSplitter used in streaming TTS input.""" + +from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter + + +class TestSentenceSplitterEnglish: + """Tests for English sentence splitting.""" + + def test_single_sentence_no_boundary(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world") + assert result == [] + assert splitter.buffer == "Hello world" + + def test_single_sentence_with_boundary(self): + splitter = SentenceSplitter() + 
result = splitter.add_text("Hello world. How are you?") + assert len(result) == 1 + assert result[0] == "Hello world." + + def test_multiple_sentences(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello. How are you? I am fine! ") + assert len(result) == 3 + assert result[0] == "Hello." + assert result[1] == "How are you?" + assert result[2] == "I am fine!" + + def test_exclamation_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Wow, that is great! Tell me more.") + assert len(result) == 1 + assert result[0] == "Wow, that is great!" + + def test_question_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Can you hear me? I hope so.") + assert len(result) == 1 + assert result[0] == "Can you hear me?" + + +class TestSentenceSplitterChinese: + """Tests for CJK sentence splitting.""" + + def test_chinese_period(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好世界。你好吗") + assert len(result) == 1 + assert result[0] == "你好世界。" + + def test_chinese_exclamation(self): + splitter = SentenceSplitter() + result = splitter.add_text("太好了!谢谢你") + assert len(result) == 1 + assert result[0] == "太好了!" + + def test_chinese_question(self): + splitter = SentenceSplitter() + result = splitter.add_text("你是谁?我是小明") + assert len(result) == 1 + assert result[0] == "你是谁?" + + def test_chinese_comma(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好,世界") + assert len(result) == 1 + assert result[0] == "你好," + + def test_chinese_semicolon(self): + splitter = SentenceSplitter() + result = splitter.add_text("第一点;第二点") + assert len(result) == 1 + assert result[0] == "第一点;" + + def test_chinese_multiple(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好!你好吗?我很好。") + assert len(result) == 3 + assert result[0] == "你好!" + assert result[1] == "你好吗?" + assert result[2] == "我很好。" + + +class TestSentenceSplitterMixed: + """Tests for mixed-language sentence splitting.""" + + def test_mixed_english_chinese(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello世界。How are you? ") + assert len(result) == 2 + assert result[0] == "Hello世界。" + assert result[1] == "How are you?" + + +class TestSentenceSplitterIncremental: + """Tests for incremental (multi-chunk) text input.""" + + def test_accumulation_across_chunks(self): + splitter = SentenceSplitter() + # First chunk: no boundary + result1 = splitter.add_text("Hello ") + assert result1 == [] + + # Second chunk: completes a sentence + result2 = splitter.add_text("world. How") + assert len(result2) == 1 + assert result2[0] == "Hello world." + assert splitter.buffer == "How" + + def test_word_by_word(self): + splitter = SentenceSplitter() + words = ["Hello, ", "how ", "are ", "you? ", "I ", "am ", "fine."] + all_sentences = [] + for word in words: + all_sentences.extend(splitter.add_text(word)) + + assert len(all_sentences) == 1 + assert all_sentences[0] == "Hello, how are you?" + # "I am fine." stays in buffer (no trailing whitespace after period) + + def test_three_chunks(self): + splitter = SentenceSplitter() + splitter.add_text("The quick brown ") + splitter.add_text("fox jumps. ") + result = splitter.add_text("Over the lazy dog. ") + # "The quick brown fox jumps." should have been returned on second chunk + # "Over the lazy dog." on third chunk + assert len(result) == 1 + assert result[0] == "Over the lazy dog." 
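# Editor's sketch (not part of this diff), assuming SentenceSplitter as defined in
# vllm_omni/entrypoints/openai/text_splitter.py: the word-by-word test above relies on a
# terminal '.' with no trailing whitespace NOT being a boundary until more text arrives.
from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter

_s = SentenceSplitter()
assert _s.add_text("I am fine.") == []         # '.' not yet followed by whitespace
assert _s.add_text(" Next") == ["I am fine."]  # boundary completed by the space
assert _s.buffer == "Next"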
+ + +class TestSentenceSplitterFlush: + """Tests for flush behavior.""" + + def test_flush_returns_remaining(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world") + result = splitter.flush() + assert result == "Hello world" + assert splitter.buffer == "" + + def test_flush_empty_buffer(self): + splitter = SentenceSplitter() + result = splitter.flush() + assert result is None + + def test_flush_after_sentence(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world. Remaining text") + result = splitter.flush() + assert result == "Remaining text" + + def test_flush_whitespace_only(self): + splitter = SentenceSplitter() + splitter.add_text("Hello. ") + # "Hello." extracted, buffer is " " + result = splitter.flush() + # Whitespace-only should return None + assert result is None + + def test_flush_clears_buffer(self): + splitter = SentenceSplitter() + splitter.add_text("some text") + splitter.flush() + assert splitter.buffer == "" + # Second flush should return None + assert splitter.flush() is None + + +class TestSentenceSplitterEdgeCases: + """Edge case tests.""" + + def test_empty_input(self): + splitter = SentenceSplitter() + result = splitter.add_text("") + assert result == [] + assert splitter.buffer == "" + + def test_none_like_empty(self): + """Empty string should not affect buffer.""" + splitter = SentenceSplitter() + splitter.add_text("Hello") + splitter.add_text("") + assert splitter.buffer == "Hello" + + def test_only_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text(". ") + # "." is 1 char, below default min_sentence_length of 2 + # It will be carried forward + assert result == [] + + def test_min_sentence_length(self): + splitter = SentenceSplitter(min_sentence_length=10) + result = splitter.add_text("Hi. Hello world. ") + # "Hi." is 3 chars (< 10), so it gets carried to "Hello world." + assert len(result) == 1 + assert "Hi." in result[0] + assert "Hello world." in result[0] + + def test_min_sentence_length_zero(self): + splitter = SentenceSplitter(min_sentence_length=0) + result = splitter.add_text("A. B. ") + assert len(result) == 2 + + def test_no_boundary_then_flush(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world how are you") + assert result == [] + flushed = splitter.flush() + assert flushed == "Hello world how are you" + + def test_consecutive_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text("Really?! Yes, really. ") + assert len(result) >= 1 + + def test_reuse_after_flush(self): + """Splitter can be reused after flush.""" + splitter = SentenceSplitter() + splitter.add_text("First session.") + splitter.flush() + + result = splitter.add_text("Second session. More text") + assert len(result) == 1 + assert result[0] == "Second session." 
+ assert splitter.buffer == "More text" diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index fb52c7e464..2f8ccf95c2 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -19,7 +19,7 @@ import httpx import vllm.envs as envs -from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, WebSocket from fastapi.responses import JSONResponse, StreamingResponse from PIL import Image from starlette.datastructures import State @@ -87,6 +87,7 @@ ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id @@ -684,6 +685,10 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, request_logger=request_logger ) + state.openai_streaming_speech = OmniStreamingSpeechHandler( + speech_service=state.openai_serving_speech, + ) + state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -819,6 +824,27 @@ async def list_voices(raw_request: Request): return JSONResponse(content={"voices": speakers}) +@router.websocket("/v1/audio/speech/stream") +async def streaming_speech(websocket: WebSocket): + """WebSocket endpoint for streaming text input TTS. + + Accepts text incrementally, splits at sentence boundaries, and + returns audio per sentence. See serving_speech_stream.py for protocol. 
+ """ + handler = getattr(websocket.app.state, "openai_streaming_speech", None) + if handler is None: + await websocket.accept() + await websocket.send_json( + { + "type": "error", + "message": "Streaming speech is not available", + } + ) + await websocket.close() + return + await handler.handle_session(websocket) + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index d23460626b..9c8a1a58a3 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -72,3 +72,19 @@ class Config: class AudioResponse(BaseModel): audio_data: bytes | str media_type: str + + +class StreamingSpeechSessionConfig(BaseModel): + """Configuration sent as the first WebSocket message for streaming TTS.""" + + model: str | None = None + voice: str | None = None + task_type: Literal["CustomVoice", "VoiceDesign", "Base"] | None = None + language: str | None = None + instructions: str | None = None + response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" + speed: float | None = Field(default=1.0, ge=0.25, le=4.0) + max_new_tokens: int | None = None + ref_audio: str | None = None + ref_text: str | None = None + x_vector_only_mode: bool | None = None diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index a8bae9e993..617a19ff0b 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -180,6 +180,109 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any return params + async def _generate_audio_bytes( + self, + request: OpenAICreateSpeechRequest, + ) -> tuple[bytes, str]: + """Core TTS generation logic: validate, generate, and encode audio. + + Extracted from create_speech() so it can be reused by the streaming + WebSocket handler for per-sentence generation. + + Args: + request: The speech request with text and parameters. + + Returns: + Tuple of (audio_bytes, media_type). + + Raises: + ValueError: If validation fails or generation produces no output. + """ + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Validate TTS parameters + if self._is_tts_model(): + validation_error = self._validate_tts_request(request) + if validation_error: + raise ValueError(validation_error) + + tts_params = self._build_tts_params(request) + prompt_text = self._build_tts_prompt(request.input) + prompt = { + "prompt": prompt_text, + "additional_information": tts_params, + } + else: + tts_params = {} + prompt = {"prompt": request.input} + + request_id = f"speech-{random_uuid()}" + + logger.info( + "TTS speech request %s: text=%r, task_type=%s", + request_id, + request.input[:50] + "..." 
if len(request.input) > 50 else request.input, + tts_params.get("task_type", ["unknown"])[0], + ) + + sampling_params_list = self.engine_client.default_sampling_params_list + + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + + final_output: OmniRequestOutput | None = None + async for res in generator: + final_output = res + + if final_output is None: + raise ValueError("No output generated from the model.") + + # Extract audio from output + audio_output = None + if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: + audio_output = final_output.multimodal_output + if not audio_output and hasattr(final_output, "request_output"): + if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): + audio_output = final_output.request_output.multimodal_output + + audio_key = None + if audio_output: + if "audio" in audio_output: + audio_key = "audio" + elif "model_outputs" in audio_output: + audio_key = "model_outputs" + + if not audio_output or audio_key is None: + raise ValueError("TTS model did not produce audio output.") + + audio_tensor = audio_output[audio_key] + sample_rate = audio_output.get("sr", 24000) + if hasattr(sample_rate, "item"): + sample_rate = sample_rate.item() + + if hasattr(audio_tensor, "float"): + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.squeeze() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=int(sample_rate), + response_format=request.response_format or "wav", + speed=request.speed or 1.0, + stream_format=request.stream_format, + base64_encode=False, + ) + + audio_response: AudioResponse = self.create_audio(audio_obj) + return audio_response.audio_data, audio_response.media_type + async def create_speech( self, request: OpenAICreateSpeechRequest, @@ -209,98 +312,9 @@ async def create_speech( logger.error("Error with model %s", error_check_ret) return error_check_ret - if self.engine_client.errored: - raise self.engine_client.dead_error - - request_id = f"speech-{random_uuid()}" - try: - if self._is_tts_model(): - # Validate TTS parameters - validation_error = self._validate_tts_request(request) - if validation_error: - return self.create_error_response(validation_error) - - # Build TTS parameters and prompt - tts_params = self._build_tts_params(request) - prompt_text = self._build_tts_prompt(request.input) - prompt = { - "prompt": prompt_text, - "additional_information": tts_params, - } - else: - # Fallback for unsupported models - tts_params = {} - prompt = {"prompt": request.input} - - logger.info( - "TTS speech request %s: text=%r, task_type=%s", - request_id, - request.input[:50] + "..." 
if len(request.input) > 50 else request.input, - tts_params.get("task_type", ["unknown"])[0], - ) - - sampling_params_list = self.engine_client.default_sampling_params_list - - generator = self.engine_client.generate( - prompt=prompt, - request_id=request_id, - sampling_params_list=sampling_params_list, - output_modalities=["audio"], - ) - - final_output: OmniRequestOutput | None = None - async for res in generator: - final_output = res - - if final_output is None: - return self.create_error_response("No output generated from the model.") - - # Extract audio from output - # Audio can be in final_output.multimodal_output or final_output.request_output.multimodal_output - # Support both "audio" and "model_outputs" keys for compatibility with different models - audio_output = None - if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: - audio_output = final_output.multimodal_output - if not audio_output and hasattr(final_output, "request_output"): - if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): - audio_output = final_output.request_output.multimodal_output - - # Check for audio data using either "audio" or "model_outputs" key - audio_key = None - if audio_output: - if "audio" in audio_output: - audio_key = "audio" - elif "model_outputs" in audio_output: - audio_key = "model_outputs" - - if not audio_output or audio_key is None: - return self.create_error_response("TTS model did not produce audio output.") - - audio_tensor = audio_output[audio_key] - sample_rate = audio_output.get("sr", 24000) - if hasattr(sample_rate, "item"): - sample_rate = sample_rate.item() - - # Convert tensor to numpy - if hasattr(audio_tensor, "float"): - audio_tensor = audio_tensor.float().detach().cpu().numpy() - - # Squeeze batch dimension if present, but preserve channel dimension for stereo - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.squeeze() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=int(sample_rate), - response_format=request.response_format or "wav", - speed=request.speed or 1.0, - stream_format=request.stream_format, - base64_encode=False, - ) - - audio_response: AudioResponse = self.create_audio(audio_obj) - return Response(content=audio_response.audio_data, media_type=audio_response.media_type) + audio_data, media_type = await self._generate_audio_bytes(request) + return Response(content=audio_data, media_type=media_type) except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm_omni/entrypoints/openai/serving_speech_stream.py b/vllm_omni/entrypoints/openai/serving_speech_stream.py new file mode 100644 index 0000000000..708ea6d5ac --- /dev/null +++ b/vllm_omni/entrypoints/openai/serving_speech_stream.py @@ -0,0 +1,234 @@ +"""WebSocket handler for streaming text input TTS. + +Accepts text incrementally via WebSocket, buffers and splits at sentence +boundaries, and generates audio per sentence using the existing TTS pipeline. 
+ +Protocol: + Client -> Server: + {"type": "session.config", ...} # Session config (sent once first) + {"type": "input.text", "text": "..."} # Text chunks + {"type": "input.done"} # End of input + + Server -> Client: + {"type": "audio.start", "sentence_index": 0, "sentence_text": "...", "format": "wav"} + + {"type": "audio.done", "sentence_index": 0} + {"type": "session.done", "total_sentences": N} + {"type": "error", "message": "..."} +""" + +import asyncio +import json + +from fastapi import WebSocket, WebSocketDisconnect +from pydantic import ValidationError +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.protocol.audio import ( + OpenAICreateSpeechRequest, + StreamingSpeechSessionConfig, +) +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter + +logger = init_logger(__name__) + +_DEFAULT_IDLE_TIMEOUT = 30.0 # seconds +_DEFAULT_CONFIG_TIMEOUT = 10.0 # seconds + + +class OmniStreamingSpeechHandler: + """Handles WebSocket sessions for streaming text-input TTS. + + Each WebSocket connection is an independent session. Text arrives + incrementally, is split at sentence boundaries, and audio is generated + per sentence using the existing OmniOpenAIServingSpeech pipeline. + + Args: + speech_service: The existing TTS serving instance (reused for + validation and audio generation). + idle_timeout: Max seconds to wait for a message before closing. + config_timeout: Max seconds to wait for the initial session.config. + """ + + def __init__( + self, + speech_service: OmniOpenAIServingSpeech, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + config_timeout: float = _DEFAULT_CONFIG_TIMEOUT, + ) -> None: + self._speech_service = speech_service + self._idle_timeout = idle_timeout + self._config_timeout = config_timeout + + async def handle_session(self, websocket: WebSocket) -> None: + """Main session loop for a single WebSocket connection.""" + await websocket.accept() + + try: + # 1. Wait for session.config + config = await self._receive_config(websocket) + if config is None: + return # Error already sent, connection closing + + splitter = SentenceSplitter() + sentence_index = 0 + + # 2. 
Receive text chunks until input.done + while True: + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + await self._send_error(websocket, "Idle timeout: no message received") + return + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON message") + continue + + msg_type = msg.get("type") + + if msg_type == "input.text": + text = msg.get("text", "") + sentences = splitter.add_text(text) + for sentence in sentences: + await self._generate_and_send(websocket, config, sentence, sentence_index) + sentence_index += 1 + + elif msg_type == "input.done": + # Flush remaining buffer + remaining = splitter.flush() + if remaining: + await self._generate_and_send(websocket, config, remaining, sentence_index) + sentence_index += 1 + + # Send session.done + await websocket.send_json( + { + "type": "session.done", + "total_sentences": sentence_index, + } + ) + return + + else: + await self._send_error( + websocket, + f"Unknown message type: {msg_type}", + ) + + except WebSocketDisconnect: + logger.info("Streaming speech: client disconnected") + except Exception as e: + logger.exception("Streaming speech session error: %s", e) + try: + await self._send_error(websocket, f"Internal error: {e}") + except Exception: + pass + + async def _receive_config(self, websocket: WebSocket) -> StreamingSpeechSessionConfig | None: + """Wait for and validate the session.config message.""" + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._config_timeout, + ) + except asyncio.TimeoutError: + await self._send_error(websocket, "Timeout waiting for session.config") + return None + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON in session.config") + return None + + if msg.get("type") != "session.config": + await self._send_error( + websocket, + f"Expected session.config, got: {msg.get('type')}", + ) + return None + + try: + config = StreamingSpeechSessionConfig(**{k: v for k, v in msg.items() if k != "type"}) + except ValidationError as e: + await self._send_error(websocket, f"Invalid session config: {e}") + return None + + return config + + async def _generate_and_send( + self, + websocket: WebSocket, + config: StreamingSpeechSessionConfig, + sentence_text: str, + sentence_index: int, + ) -> None: + """Generate audio for a single sentence and send it over WebSocket.""" + response_format = config.response_format or "wav" + + # Send audio.start + await websocket.send_json( + { + "type": "audio.start", + "sentence_index": sentence_index, + "sentence_text": sentence_text, + "format": response_format, + } + ) + + try: + # Build a per-sentence request reusing session config + request = OpenAICreateSpeechRequest( + input=sentence_text, + model=config.model, + voice=config.voice, + task_type=config.task_type, + language=config.language, + instructions=config.instructions, + response_format=response_format, + speed=config.speed, + max_new_tokens=config.max_new_tokens, + ref_audio=config.ref_audio, + ref_text=config.ref_text, + x_vector_only_mode=config.x_vector_only_mode, + ) + + audio_bytes, _ = await self._speech_service._generate_audio_bytes(request) + + # Send binary audio frame + await websocket.send_bytes(audio_bytes) + + except Exception as e: + logger.error("Generation failed for sentence %d: %s", sentence_index, e) + await self._send_error( + websocket, + f"Generation failed for sentence 
{sentence_index}: {e}", + ) + + # Send audio.done (even on error, so client can track progress) + await websocket.send_json( + { + "type": "audio.done", + "sentence_index": sentence_index, + } + ) + + @staticmethod + async def _send_error(websocket: WebSocket, message: str) -> None: + """Send an error message to the client.""" + try: + await websocket.send_json( + { + "type": "error", + "message": message, + } + ) + except Exception: + pass # Connection may already be closed diff --git a/vllm_omni/entrypoints/openai/text_splitter.py b/vllm_omni/entrypoints/openai/text_splitter.py new file mode 100644 index 0000000000..27d30f417e --- /dev/null +++ b/vllm_omni/entrypoints/openai/text_splitter.py @@ -0,0 +1,90 @@ +"""Multi-language sentence boundary detector for streaming TTS input. + +Buffers incoming text and splits at sentence boundaries (English and CJK), +yielding complete sentences for audio generation. +""" + +import re + +# Sentence boundary pattern: +# - English: .!? followed by whitespace or end of string +# - CJK fullwidth: 。!?,; +_SENTENCE_BOUNDARY_RE = re.compile( + r"(?<=[.!?])\s+" # English punctuation followed by whitespace + r"|(?<=[。!?,;])" # CJK fullwidth punctuation +) + + +class SentenceSplitter: + """Incremental sentence splitter for streaming text input. + + Buffers text and yields complete sentences when boundaries are detected. + Designed for TTS pipelines where text arrives incrementally (e.g., from STT). + + Args: + min_sentence_length: Minimum character length for a sentence. + Sentences shorter than this are kept in the buffer to avoid + splitting on abbreviations like "Dr." or "U.S.". + """ + + def __init__(self, min_sentence_length: int = 2) -> None: + self._buffer: str = "" + self._min_sentence_length = min_sentence_length + + @property + def buffer(self) -> str: + """Current buffered text.""" + return self._buffer + + def add_text(self, text: str) -> list[str]: + """Add text to the buffer and return any complete sentences. + + Args: + text: Incoming text chunk. + + Returns: + List of complete sentences extracted from the buffer. + May be empty if no sentence boundary was found. + """ + if not text: + return [] + + self._buffer += text + return self._extract_sentences() + + def flush(self) -> str | None: + """Flush remaining buffered text as a final sentence. + + Returns: + The remaining buffered text (stripped), or None if buffer is empty. + """ + remaining = self._buffer.strip() + self._buffer = "" + return remaining if remaining else None + + def _extract_sentences(self) -> list[str]: + """Split buffer at sentence boundaries, keeping incomplete text buffered.""" + parts = _SENTENCE_BOUNDARY_RE.split(self._buffer) + + if len(parts) <= 1: + # No boundary found — keep everything in buffer + return [] + + sentences: list[str] = [] + carry = "" + # All parts except the last are complete sentences + for i in range(len(parts) - 1): + text = carry + parts[i] + carry = "" + stripped = text.strip() + if len(stripped) >= self._min_sentence_length: + sentences.append(stripped) + elif stripped: + # Too short (e.g. "Dr.") — carry forward to next part + carry = text + # else: empty, skip + + # Last part stays in buffer (may be incomplete) + self._buffer = carry + parts[-1] + + return sentences
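For reference, a short usage sketch of the splitter above, following the same pattern the WebSocket handler uses (feed chunks as they arrive, synthesize each returned sentence, flush on `input.done`); the return values shown assume the default `min_sentence_length=2`:

from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter

splitter = SentenceSplitter()
splitter.add_text("Hello ")        # []               -- no boundary yet, text stays buffered
splitter.add_text("world. 你好")    # ['Hello world.'] -- '.' followed by whitespace is a boundary
splitter.add_text("!再见")          # ['你好!']        -- fullwidth '!' splits immediately
splitter.flush()                    # '再见'           -- remaining buffer flushed at input.done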