Skip to content
255 changes: 255 additions & 0 deletions tests/e2e/online_serving/test_qwen3_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E Online tests for Qwen3-TTS model with text input and audio output.

These tests verify the /v1/audio/speech endpoint works correctly with
actual model inference, not mocks.
"""

import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"

import threading
from pathlib import Path

import httpx
import pytest

from tests.conftest import OmniServer

# Model variants for different TTS tasks (using 0.6B for faster CI)
models = {
"CustomVoice": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
"VoiceDesign": "Qwen/Qwen3-TTS-12Hz-0.6B-VoiceDesign",
}


def get_stage_config():
"""Get the stage config path for Qwen3-TTS."""
return str(
Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
)


_omni_server_lock = threading.Lock()


@pytest.fixture(scope="module")
def omni_server_customvoice(request):
"""Start vLLM-Omni server with CustomVoice model."""
with _omni_server_lock:
model = models["CustomVoice"]
stage_config_path = get_stage_config()

print(f"Starting OmniServer with model: {model}")

with OmniServer(
model,
[
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
"120",
"--trust-remote-code",
"--enforce-eager",
],
) as server:
print("OmniServer started successfully")
yield server
print("OmniServer stopping...")

print("OmniServer stopped")


@pytest.fixture(scope="module")
def omni_server_voicedesign(request):
"""Start vLLM-Omni server with VoiceDesign model."""
with _omni_server_lock:
model = models["VoiceDesign"]
stage_config_path = get_stage_config()

print(f"Starting OmniServer with model: {model}")

with OmniServer(
model,
[
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
"120",
"--trust-remote-code",
"--enforce-eager",
],
) as server:
print("OmniServer started successfully")
yield server
print("OmniServer stopping...")

print("OmniServer stopped")


def make_speech_request(
host: str,
port: int,
text: str,
voice: str = "vivian",
language: str = "English",
task_type: str | None = None,
instructions: str | None = None,
timeout: float = 300.0,
) -> httpx.Response:
"""Make a request to the /v1/audio/speech endpoint."""
url = f"http://{host}:{port}/v1/audio/speech"
payload = {
"input": text,
"voice": voice,
"language": language,
}
if task_type:
payload["task_type"] = task_type
if instructions:
payload["instructions"] = instructions

with httpx.Client(timeout=timeout) as client:
return client.post(url, json=payload)


def verify_wav_audio(content: bytes) -> bool:
"""Verify that content is valid WAV audio data."""
# WAV files start with "RIFF" header
if len(content) < 44: # Minimum WAV header size
return False
return content[:4] == b"RIFF" and content[8:12] == b"WAVE"


class TestQwen3TTSCustomVoice:
"""E2E tests for Qwen3-TTS CustomVoice model."""

def test_speech_english_basic(self, omni_server_customvoice) -> None:
"""Test basic English TTS generation."""
response = make_speech_request(
host=omni_server_customvoice.host,
port=omni_server_customvoice.port,
text="Hello, how are you?",
voice="vivian",
language="English",
)

assert response.status_code == 200, f"Request failed: {response.text}"
assert response.headers.get("content-type") == "audio/wav"
assert verify_wav_audio(response.content), "Response is not valid WAV audio"
assert len(response.content) > 1000, "Audio content too small"

def test_speech_chinese_basic(self, omni_server_customvoice) -> None:
"""Test basic Chinese TTS generation."""
response = make_speech_request(
host=omni_server_customvoice.host,
port=omni_server_customvoice.port,
text="你好,我是通义千问",
voice="vivian",
language="Chinese",
)

assert response.status_code == 200, f"Request failed: {response.text}"
assert response.headers.get("content-type") == "audio/wav"
assert verify_wav_audio(response.content), "Response is not valid WAV audio"
assert len(response.content) > 1000, "Audio content too small"

def test_speech_different_voices(self, omni_server_customvoice) -> None:
"""Test TTS with different voice options."""
voices = ["vivian", "ryan"]
for voice in voices:
response = make_speech_request(
host=omni_server_customvoice.host,
port=omni_server_customvoice.port,
text="Testing voice selection.",
voice=voice,
language="English",
)

assert response.status_code == 200, f"Request failed for voice {voice}: {response.text}"
assert verify_wav_audio(response.content), f"Invalid WAV for voice {voice}"

def test_speech_binary_response_not_utf8_error(self, omni_server_customvoice) -> None:
"""
Regression test: Verify binary audio is returned, not UTF-8 error.

This test ensures the multimodal_output property correctly retrieves
audio from completion outputs, preventing the "TTS model did not
produce audio output" error.
"""
response = make_speech_request(
host=omni_server_customvoice.host,
port=omni_server_customvoice.port,
text="This should return binary audio, not a JSON error.",
voice="vivian",
language="English",
)

# Should NOT be a JSON error response
assert response.status_code == 200, f"Request failed: {response.text}"

# Verify it's binary audio, not JSON
try:
# If this succeeds and starts with {"error", it's a bug
text = response.content.decode("utf-8")
assert not text.startswith('{"error"'), f"Got error response instead of audio: {text}"
except UnicodeDecodeError:
# This is expected - binary audio can't be decoded as UTF-8
pass

assert verify_wav_audio(response.content), "Response is not valid WAV audio"


class TestQwen3TTSVoiceDesign:
"""E2E tests for Qwen3-TTS VoiceDesign model."""

def test_speech_with_voice_description(self, omni_server_voicedesign) -> None:
"""Test TTS with natural language voice description."""
response = make_speech_request(
host=omni_server_voicedesign.host,
port=omni_server_voicedesign.port,
text="Hello, this is a test.",
task_type="VoiceDesign",
instructions="A warm, friendly female voice with a gentle tone",
)

assert response.status_code == 200, f"Request failed: {response.text}"
assert response.headers.get("content-type") == "audio/wav"
assert verify_wav_audio(response.content), "Response is not valid WAV audio"
assert len(response.content) > 1000, "Audio content too small"


class TestQwen3TTSAPIEndpoints:
"""Test API endpoint functionality."""

def test_list_voices_endpoint(self, omni_server_customvoice) -> None:
"""Test the /v1/audio/voices endpoint returns available voices."""
url = f"http://{omni_server_customvoice.host}:{omni_server_customvoice.port}/v1/audio/voices"

with httpx.Client(timeout=30.0) as client:
response = client.get(url)

assert response.status_code == 200
data = response.json()
assert "voices" in data
assert isinstance(data["voices"], list)
assert len(data["voices"]) > 0
# Check some expected voices are present
voices_lower = [v.lower() for v in data["voices"]]
assert "vivian" in voices_lower or "ryan" in voices_lower

def test_models_endpoint(self, omni_server_customvoice) -> None:
"""Test the /v1/models endpoint returns loaded model."""
url = f"http://{omni_server_customvoice.host}:{omni_server_customvoice.port}/v1/models"

with httpx.Client(timeout=30.0) as client:
response = client.get(url)

assert response.status_code == 200
data = response.json()
assert "data" in data
assert len(data["data"]) > 0
Loading