diff --git a/examples/online_serving/stable_audio/README.md b/examples/online_serving/stable_audio/README.md new file mode 100644 index 0000000000..a390dc61e3 --- /dev/null +++ b/examples/online_serving/stable_audio/README.md @@ -0,0 +1,234 @@ +# Stable Audio Online Serving + +Generate audio from text prompts using Stable Audio models via an OpenAI-compatible API endpoint. + +## Features + +- **OpenAI-compatible API**: Use `/v1/audio/speech` endpoint +- **Flexible control**: Adjust audio length, guidance scale, inference steps +- **Quality control**: Use negative prompts to avoid unwanted characteristics +- **Reproducible**: Set random seed for deterministic generation + +## Quick Start + +### 1. Start the Server + +```bash +vllm-omni serve stabilityai/stable-audio-open-1.0 \ + --host 0.0.0.0 \ + --port 8000 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --enforce-eager \ + --omni +``` + +### 2. Generate Audio + +#### Using curl + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of a cat purring", + "audio_length": 10.0 + }' --output cat.wav +``` + +#### Using Python Client + +```bash +python stable_audio_client.py \ + --text "The sound of a cat purring" \ + --audio_length 10.0 \ + --output cat.wav +``` + +#### Using Bash Script + +```bash +bash curl_examples.sh +``` + +## API Reference + +### Endpoint + +``` +POST /v1/audio/speech +``` + +### Request Body + +```json +{ + "input": "Text description of the audio", + "audio_length": 10.0, + "audio_start": 0.0, + "negative_prompt": "Low quality", + "guidance_scale": 7.0, + "num_inference_steps": 100, + "seed": 42, + "response_format": "wav" +} +``` + +### Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `input` | string | **required** | Text prompt describing the audio to generate | +| `audio_length` | float | 10.0 | Audio duration in seconds (max ~47s for stable-audio-open-1.0) | +| `audio_start` | float | 0.0 | Audio start time in seconds | +| `negative_prompt` | string | null | Text describing what to avoid in generation | +| `guidance_scale` | float | 7.0 | Classifier-free guidance scale (higher = more adherence to prompt) | +| `num_inference_steps` | int | 100 | Number of denoising steps (higher = better quality, slower) | +| `seed` | int | null | Random seed for reproducibility | +| `response_format` | string | "wav" | Output format: wav, mp3, flac, pcm | + +### Response + +Returns audio data in the requested format (default: WAV). 
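
If you prefer not to use the bundled client, the request body documented above can be sent directly from Python. The snippet below is a minimal sketch (not one of the shipped example scripts) and assumes the server from the Quick Start is listening on `localhost:8000`:

```python
import requests

payload = {
    "input": "The sound of ocean waves",  # required: text prompt
    "audio_length": 10.0,                 # seconds of audio to generate
    "negative_prompt": "Low quality",
    "guidance_scale": 7.0,
    "num_inference_steps": 100,
    "seed": 42,
    "response_format": "wav",
}

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json=payload,
    timeout=300,  # diffusion can take a while for long clips
)
resp.raise_for_status()

# The response body is the raw audio in the requested format.
with open("ocean.wav", "wb") as f:
    f.write(resp.content)
```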
+ +## Usage Examples + +### Basic Generation + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The sound of ocean waves" + }' --output ocean.wav +``` + +### Custom Duration + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A dog barking", + "audio_length": 5.0 + }' --output dog_5s.wav +``` + +### High Quality with Negative Prompt + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "A piano playing a gentle melody", + "audio_length": 10.0, + "negative_prompt": "Low quality, distorted, noisy", + "guidance_scale": 8.0, + "num_inference_steps": 150 + }' --output piano_hq.wav +``` + +### Reproducible Generation + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Thunder and rain sounds", + "audio_length": 15.0, + "seed": 42 + }' --output thunder.wav +``` + +### Quick Generation (Fewer Steps) + +For faster generation with slightly lower quality: + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Birds chirping in a forest", + "audio_length": 8.0, + "num_inference_steps": 50 + }' --output birds_quick.wav +``` + +## Python Client Examples + +### Simple Generation + +```bash +python stable_audio_client.py \ + --text "The sound of a cat purring" +``` + +### Custom Parameters + +```bash +python stable_audio_client.py \ + --text "Thunder and rain" \ + --audio_length 15.0 \ + --negative_prompt "Low quality" \ + --guidance_scale 7.0 \ + --num_inference_steps 100 \ + --seed 42 \ + --output thunder.wav +``` + +### Different Output Format + +```bash +python stable_audio_client.py \ + --text "Guitar playing" \ + --response_format mp3 \ + --output guitar.mp3 +``` + +## Tips + +1. **Audio Length**: Keep under 47 seconds for `stable-audio-open-1.0` +2. **Quality vs Speed**: + - 50 steps: Fast, decent quality + - 100 steps: Good balance (default) + - 150+ steps: High quality, slower +3. **Guidance Scale**: + - Lower (3-5): More creative/varied + - Default (7): Good balance + - Higher (10+): More literal to prompt +4. **Negative Prompts**: Use to avoid "Low quality", "distorted", "noisy", etc. +5. 
**Seeds**: Use same seed for reproducible results
+
+## Performance
+
+| Inference Steps | Quality | Speed | Use Case |
+|----------------|---------|-------|----------|
+| 50 | Good | Fast | Quick previews |
+| 100 (default) | Very Good | Medium | Production |
+| 150+ | Excellent | Slow | Final/critical audio |
+
+## Troubleshooting
+
+### Server not responding
+- Check if server is running: `curl http://localhost:8000/health`
+- Check server logs for errors
+
+### Audio quality issues
+- Increase `num_inference_steps` (e.g., 150)
+- Add negative prompts: `"Low quality, distorted, noisy"`
+- Increase `guidance_scale` for more prompt adherence
+
+### Generation timeout
+- Reduce `num_inference_steps`
+- Reduce `audio_length`
+- Check GPU memory with `nvidia-smi`
+
+### Wrong audio length
+- Ensure `audio_length` is within model limits (~47s max)
+- Adjust `audio_start` if trimming is needed
+
+## See Also
+
+- [Offline Inference Example](../../offline_inference/text_to_audio/README.md)
+- [Stable Audio Model Card](https://huggingface.co/stabilityai/stable-audio-open-1.0)
+- [vLLM-Omni Documentation](https://github.com/vllm-project/vllm-omni)
diff --git a/examples/online_serving/stable_audio/curl_examples.sh b/examples/online_serving/stable_audio/curl_examples.sh
new file mode 100755
index 0000000000..5763b7cbac
--- /dev/null
+++ b/examples/online_serving/stable_audio/curl_examples.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+# Examples for using Stable Audio with curl via /v1/audio/speech endpoint
+
+# Example 1: Simple request with default parameters
+echo "Example 1: Simple request with default parameters"
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "The sound of an audience clapping and cheering in a stadium"
+  }' --output stadium.wav
+
+# Example 2: Request with custom audio_length
+echo "Example 2: Custom audio length (5 seconds)"
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "The sound of a dog barking",
+    "audio_length": 5.0
+  }' --output dog_5s.wav
+
+# Example 3: Request with negative prompt for quality control
+echo "Example 3: With negative prompt"
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "A piano playing a gentle melody",
+    "audio_length": 10.0,
+    "negative_prompt": "Low quality, distorted, noisy"
+  }' --output piano.wav
+
+# Example 4: Full control with all parameters
+echo "Example 4: Full control (custom length, guidance, steps, seed)"
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "Thunder and rain sounds",
+    "audio_length": 15.0,
+    "negative_prompt": "Low quality",
+    "guidance_scale": 7.0,
+    "num_inference_steps": 100,
+    "seed": 42
+  }' --output thunder_rain.wav
+
+# Example 5: Quick generation with fewer steps (faster but lower quality)
+echo "Example 5: Quick generation (fewer steps)"
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "Ocean waves crashing on a beach",
+    "audio_length": 8.0,
+    "num_inference_steps": 50
+  }' --output ocean.wav
+
+echo "All examples completed!"
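
The script above leaves a handful of `.wav` files in the working directory. As a quick way to confirm they came back with the expected length and sample rate, here is a small sketch using only the Python standard library (it assumes the server returns standard PCM WAV, and that the filenames match the `--output` paths used above):

```python
import wave

for path in ["stadium.wav", "dog_5s.wav", "piano.wav", "thunder_rain.wav", "ocean.wav"]:
    with wave.open(path, "rb") as wf:
        sample_rate = wf.getframerate()
        duration = wf.getnframes() / sample_rate
        print(f"{path}: {sample_rate} Hz, {wf.getnchannels()} channel(s), {duration:.1f}s")
```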
diff --git a/examples/online_serving/stable_audio/stable_audio_client.py b/examples/online_serving/stable_audio/stable_audio_client.py new file mode 100755 index 0000000000..26495298e6 --- /dev/null +++ b/examples/online_serving/stable_audio/stable_audio_client.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +OpenAI-compatible client for Stable Audio via /v1/audio/speech endpoint. + +This script demonstrates how to use the OpenAI-compatible speech API +to generate audio from text using Stable Audio models. + +Examples: + # Simple generation + python stable_audio_client.py --text "The sound of a cat purring" + + # With custom duration + python stable_audio_client.py --text "A dog barking" --audio_length 5.0 + + # With all parameters + python stable_audio_client.py --text "Thunder and rain" \ + --audio_length 15.0 \ + --negative_prompt "Low quality" \ + --guidance_scale 7.0 \ + --num_inference_steps 100 \ + --seed 42 \ + --output thunder.wav +""" + +import argparse +import sys + +import requests + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate audio with Stable Audio via OpenAI-compatible API") + parser.add_argument( + "--api_url", + default="http://localhost:8000/v1/audio/speech", + help="API endpoint URL", + ) + parser.add_argument( + "--text", + default="The sound of a cat purring", + help="Text prompt for audio generation", + ) + parser.add_argument( + "--audio_length", + type=float, + default=10.0, + help="Audio length in seconds (max ~47s for stable-audio-open-1.0)", + ) + parser.add_argument( + "--audio_start", + type=float, + default=0.0, + help="Audio start time in seconds", + ) + parser.add_argument( + "--negative_prompt", + default="Low quality", + help="Negative prompt for classifier-free guidance", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7.0, + help="Guidance scale for diffusion (higher = more adherence to prompt)", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="Number of inference steps (higher = better quality, slower)", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducibility", + ) + parser.add_argument( + "--output", + default="stable_audio_output.wav", + help="Output file path", + ) + parser.add_argument( + "--response_format", + default="wav", + choices=["wav", "mp3", "flac", "pcm"], + help="Audio output format", + ) + return parser.parse_args() + + +def generate_audio(args): + """Generate audio using the API.""" + + # Build request payload + payload = { + "input": args.text, + "audio_length": args.audio_length, + "audio_start": args.audio_start, + "response_format": args.response_format, + } + + # Add optional parameters + if args.negative_prompt: + payload["negative_prompt"] = args.negative_prompt + if args.guidance_scale: + payload["guidance_scale"] = args.guidance_scale + if args.num_inference_steps: + payload["num_inference_steps"] = args.num_inference_steps + if args.seed is not None: + payload["seed"] = args.seed + + print(f"\n{'=' * 60}") + print("Stable Audio - Text-to-Audio Generation") + print(f"{'=' * 60}") + print(f"API URL: {args.api_url}") + print(f"Prompt: {args.text}") + print(f"Audio length: {args.audio_length}s") + print(f"Negative prompt: {args.negative_prompt}") + print(f"Guidance scale: {args.guidance_scale}") + print(f"Inference steps: {args.num_inference_steps}") + if args.seed is not None: + print(f"Seed: {args.seed}") + print(f"Output: {args.output}") + print(f"{'=' * 60}\n") + + try: + # 
Make the API request + print("Generating audio...") + response = requests.post( + args.api_url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=300, # 5 minute timeout for long generations + ) + + # Check for errors + if response.status_code != 200: + print(f"Error: API returned status code {response.status_code}") + print(f"Response: {response.text}") + return False + + # Save the audio + with open(args.output, "wb") as f: + f.write(response.content) + + print(f"✓ Audio saved to {args.output}") + print(f" File size: {len(response.content) / 1024:.1f} KB") + return True + + except requests.exceptions.Timeout: + print("Error: Request timed out. Try reducing inference steps or audio length.") + return False + except requests.exceptions.ConnectionError: + print(f"Error: Could not connect to {args.api_url}") + print("Make sure the server is running.") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def main(): + args = parse_args() + success = generate_audio(args) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py index c4884e3abd..1d4636eef8 100644 --- a/tests/entrypoints/test_omni_diffusion.py +++ b/tests/entrypoints/test_omni_diffusion.py @@ -447,7 +447,7 @@ def _mock_cached_file(path_or_repo_id, *args, **kwargs): def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): """Test that stage configs are auto-loaded when stage_configs_path is None.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [ _FakeStageConfig(fake_stage_config), _FakeStageConfig(fake_stage_config), @@ -512,7 +512,7 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -564,7 +564,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): stage_cfg1 = dict(fake_stage_config) stage_cfg1["processed_input"] = ["processed-for-stage-1"] - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -663,7 +663,7 @@ def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config): stage_cfg1 = dict(fake_stage_config) stage_cfg0["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -770,7 +770,7 @@ def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): stage_cfg0["final_output"] = False stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -849,7 +849,7 @@ def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_confi stage_cfg0["final_output"] = False stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): + def 
_fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -917,7 +917,7 @@ def _fake_loader(model: str, base_engine_args=None): def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): """Test that _wait_for_stages_ready handles timeout correctly.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -973,7 +973,7 @@ def init_stage_worker(self, *args, **kwargs): def test_generate_handles_error_messages(monkeypatch, fake_stage_config): """Test that generate handles error messages from stages correctly.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -1048,7 +1048,7 @@ def _fake_loader(model: str, base_engine_args=None): def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): """Test that close() sends shutdown signal to all input queues.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index f99c6d8336..1a5b55db65 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -447,7 +447,7 @@ def _mock_cached_file(path_or_repo_id, *args, **kwargs): def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config): """Test that stage configs are auto-loaded when stage_configs_path is None.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [ _FakeStageConfig(fake_stage_config), _FakeStageConfig(fake_stage_config), @@ -512,7 +512,7 @@ def _fake_loader(model: str, base_engine_args=None): def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config): """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -564,7 +564,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config): stage_cfg1 = dict(fake_stage_config) stage_cfg1["processed_input"] = ["processed-for-stage-1"] - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -664,7 +664,7 @@ def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config): stage_cfg0["final_output"] = False stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -743,7 +743,7 @@ def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_confi stage_cfg0["final_output"] = False stage_cfg1["final_output"] = False - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] import sys @@ -811,7 +811,7 @@ def _fake_loader(model: str, base_engine_args=None): def 
test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config): """Test that _wait_for_stages_ready handles timeout correctly.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -867,7 +867,7 @@ def init_stage_worker(self, *args, **kwargs): def test_generate_handles_error_messages(monkeypatch, fake_stage_config): """Test that generate handles error messages from stages correctly.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys @@ -942,7 +942,7 @@ def _fake_loader(model: str, base_engine_args=None): def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config): """Test that close() sends shutdown signal to all input queues.""" - def _fake_loader(model: str, base_engine_args=None): + def _fake_loader(config_path: str, base_engine_args=None): return [_FakeStageConfig(fake_stage_config)] import sys diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 97357dc3b3..9998fec9b3 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -40,6 +40,7 @@ load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, + resolve_model_type, ) from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -202,11 +203,14 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: tokenizer = kwargs.get("tokenizer", None) base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None + self.model_type = resolve_model_type(model) # Load stage configurations from YAML if stage_configs_path is None: - self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args) + self.config_path = resolve_model_config_path(self.model_type) + self.stage_configs = load_stage_configs_from_model( + config_path=self.config_path, base_engine_args=base_engine_args + ) if not self.stage_configs: default_stage_cfg = self._create_default_diffusion_stage_cfg(kwargs) self.stage_configs = OmegaConf.create(default_stage_cfg) diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 0356ea2310..295b3f5d81 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -25,6 +25,7 @@ load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, + resolve_model_type, ) logger = init_logger(__name__) @@ -82,11 +83,12 @@ def __init__( self.ray_address = kwargs.get("ray_address", None) self.batch_timeout = batch_timeout self._enable_stats: bool = bool(log_stats) + self.model_type = resolve_model_type(model) # Load stage configurations if stage_configs_path is None: - self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model) + self.config_path = resolve_model_config_path(self.model_type) + self.stage_configs = load_stage_configs_from_model(config_path=self.config_path) else: self.config_path = stage_configs_path self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index fb52c7e464..45357e7f23 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ 
b/vllm_omni/entrypoints/openai/api_server.py @@ -138,8 +138,30 @@ class _DiffusionServingModels: provide a lightweight fallback. """ + class _NullModelConfig: + def __getattr__(self, name): + return None + + class _Unsupported: + def __init__(self, name: str): + self.name = name + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"{self.name} is not supported in diffusion mode") + + def __getattr__(self, attr): + raise NotImplementedError(f"{self.name}.{attr} is not supported in diffusion mode") + def __init__(self, base_model_paths: list[BaseModelPath]) -> None: self._base_model_paths = base_model_paths + self.model_config = self._NullModelConfig() + + def __getattr__(self, name): + """ + Any attribute OpenAIServing tries to access but we don't explicitly define + will safely resolve to None. + """ + return self._Unsupported(name) async def show_available_models(self) -> ModelList: return ModelList( @@ -390,10 +412,10 @@ async def omni_init_app_state( # For omni models state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None + model_name = served_model_names[0] if served_model_names else args.model # Pure Diffusion mode: use simplified initialization logic if is_pure_diffusion: - model_name = served_model_names[0] if served_model_names else args.model state.vllm_config = None state.diffusion_engine = engine_client state.openai_serving_models = _DiffusionServingModels(base_model_paths) @@ -406,6 +428,13 @@ async def omni_init_app_state( model_name=model_name, ) + state.openai_serving_speech = OmniOpenAIServingSpeech.for_diffusion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + model_name=model_name, + ) + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 logger.info("Pure diffusion API server initialized for model: %s", model_name) @@ -681,7 +710,7 @@ async def omni_init_app_state( ) state.openai_serving_speech = OmniOpenAIServingSpeech( - engine_client, state.openai_serving_models, request_logger=request_logger + engine_client, state.openai_serving_models, request_logger=request_logger, model_name=model_name ) state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py index 13df32ebe0..3023b67ad6 100644 --- a/vllm_omni/entrypoints/openai/audio_utils_mixin.py +++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py @@ -45,6 +45,10 @@ def create_audio(self, audio_obj: CreateAudio) -> AudioResponse: "Only mono (1D) and stereo (2D) are supported." 
) + if audio_tensor.ndim == 2 and audio_tensor.shape[0] == 2: + # Convert from [channels, samples] to [samples, channels] + audio_tensor = audio_tensor.T + audio_tensor, sample_rate = self._apply_speed_adjustment(audio_tensor, speed, sample_rate) supported_formats = { diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index d23460626b..81ca052c14 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -49,6 +49,32 @@ class OpenAICreateSpeechRequest(BaseModel): description="Maximum tokens to generate", ) + # Stable Audio specific parameters + audio_length: float | None = Field( + default=None, + description="Audio length in seconds (for Stable Audio models)", + ) + audio_start: float | None = Field( + default=0.0, + description="Audio start time in seconds (for Stable Audio models)", + ) + negative_prompt: str | None = Field( + default=None, + description="Negative prompt for classifier-free guidance (for Stable Audio models)", + ) + guidance_scale: float | None = Field( + default=None, + description="Guidance scale for diffusion models (for Stable Audio models)", + ) + num_inference_steps: int | None = Field( + default=None, + description="Number of inference steps (for Stable Audio models)", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducibility (for Stable Audio models)", + ) + @field_validator("stream_format") @classmethod def validate_stream_format(cls, v: str) -> str: diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index a8bae9e993..d07105e0d9 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,6 +1,7 @@ import asyncio from typing import Any +import torch from fastapi import Request from fastapi.responses import Response from vllm.entrypoints.openai.engine.serving import OpenAIServing @@ -13,6 +14,7 @@ CreateAudio, OpenAICreateSpeechRequest, ) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -39,11 +41,19 @@ class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin): def __init__(self, *args, **kwargs): + self.model_name = kwargs.pop("model_name", None) super().__init__(*args, **kwargs) + self.diffusion_mode = False # Load supported speakers self.supported_speakers = self._load_supported_speakers() logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + @classmethod + def for_diffusion(cls, *args, **kwargs) -> bool: + """Check if the current instance is in diffusion mode.""" + cls.diffusion_mode = True + return cls(*args, **kwargs) + def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" try: @@ -72,6 +82,9 @@ def _is_tts_model(self) -> bool: return True return False + def _is_stable_audio_model(self) -> bool: + return self.engine_client.model_type == "StableAudioPipeline" + def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate TTS request parameters. 
Returns error message or None.""" task_type = request.task_type or "CustomVoice" @@ -215,6 +228,8 @@ async def create_speech( request_id = f"speech-{random_uuid()}" try: + sampling_params_list = self.engine_client.default_sampling_params_list + default_sr = 24000 # Default sample rate for TTS models if self._is_tts_model(): # Validate TTS parameters validation_error = self._validate_tts_request(request) @@ -228,6 +243,45 @@ async def create_speech( "prompt": prompt_text, "additional_information": tts_params, } + elif self._is_stable_audio_model(): + # Handle Stable Audio models + # Stable Audio uses diffusion, needs different parameters + default_sr = 44100 # Default sample rate for Stable Audio + + # Build prompt for Stable Audio + prompt = { + "prompt": request.input, + } + if request.negative_prompt: + prompt["negative_prompt"] = request.negative_prompt + + # Build sampling params for diffusion + sampling_params_list = [OmniDiffusionSamplingParams(num_outputs_per_prompt=1)] + + # Create generator if seed provided + if request.seed is not None: + from vllm_omni.platforms import current_omni_platform + + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(request.seed) + sampling_params_list[0].generator = generator + + if request.guidance_scale is not None: + sampling_params_list[0].guidance_scale = request.guidance_scale + + if request.num_inference_steps is not None: + sampling_params_list[0].num_inference_steps = request.num_inference_steps + + # Set up audio duration parameters + if request.audio_length is not None: + audio_length = request.audio_length + audio_start = request.audio_start if request.audio_start is not None else 0.0 + audio_end_in_s = audio_start + audio_length + sampling_params_list[0].extra_args = { + "audio_start_in_s": audio_start, + "audio_end_in_s": audio_end_in_s, + } + + tts_params = {} else: # Fallback for unsupported models tts_params = {} @@ -240,8 +294,6 @@ async def create_speech( tts_params.get("task_type", ["unknown"])[0], ) - sampling_params_list = self.engine_client.default_sampling_params_list - generator = self.engine_client.generate( prompt=prompt, request_id=request_id, @@ -278,7 +330,7 @@ async def create_speech( return self.create_error_response("TTS model did not produce audio output.") audio_tensor = audio_output[audio_key] - sample_rate = audio_output.get("sr", 24000) + sample_rate = audio_output.get("sr", default_sr) if hasattr(sample_rate, "item"): sample_rate = sample_rate.item() diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 4ebff32903..6d707d5bd9 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -117,7 +117,7 @@ def _convert_dataclasses_to_dict(obj: Any) -> Any: return obj -def resolve_model_config_path(model: str) -> str: +def resolve_model_config_path(model_type: str) -> str: """Resolve the stage config file path from the model name. Resolves stage configuration path based on the model type and device type. 
@@ -134,6 +134,29 @@ def resolve_model_config_path(model: str) -> str: ValueError: If model_type cannot be determined FileNotFoundError: If no stage config file exists for the model type """ + default_config_path = current_omni_platform.get_default_stage_config_path() + config_file_name = f"{model_type}.yaml" + complete_config_path = PROJECT_ROOT / default_config_path / config_file_name + if os.path.exists(complete_config_path): + return str(complete_config_path) + + # Fall back to default config + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + return None + return str(stage_config_path) + + +def resolve_model_type(model: str) -> str: + """Resolve the model type from the model name. + + Args: + model: Model name or path + + Returns: + Model type string if found, None otherwise + """ # Try to get config from standard transformers format first try: hf_config = get_config(model, trust_remote_code=True) @@ -165,21 +188,10 @@ def resolve_model_config_path(model: str) -> str: f"Please ensure the model has proper configuration files with 'model_type' field" ) - default_config_path = current_omni_platform.get_default_stage_config_path() - model_type_str = f"{model_type}.yaml" - complete_config_path = PROJECT_ROOT / default_config_path / model_type_str - if os.path.exists(complete_config_path): - return str(complete_config_path) + return model_type - # Fall back to default config - stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" - stage_config_path = PROJECT_ROOT / stage_config_file - if not os.path.exists(stage_config_path): - return None - return str(stage_config_path) - -def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: +def load_stage_configs_from_model(config_path: str | None, base_engine_args: dict | None = None) -> list: """Load stage configurations from model's default config file. Loads stage configurations based on the model type and device type. @@ -187,7 +199,7 @@ def load_stage_configs_from_model(model: str, base_engine_args: dict | None = No directory. If not found, falls back to the default config file. Args: - model: Model name or path (used to determine model_type) + config_path: Path to the YAML configuration file Returns: List of stage configuration dictionaries @@ -197,10 +209,9 @@ def load_stage_configs_from_model(model: str, base_engine_args: dict | None = No """ if base_engine_args is None: base_engine_args = {} - stage_config_path = resolve_model_config_path(model) - if stage_config_path is None: + if config_path is None: return [] - stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args) + stage_configs = load_stage_configs_from_yaml(config_path=config_path, base_engine_args=base_engine_args) return stage_configs
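
Taken together, the helper changes in `vllm_omni/entrypoints/utils.py` split model-type resolution out of config-path resolution. The sketch below mirrors how the updated call sites in `omni.py` and `omni_llm.py` compose them; the model id is only a placeholder, and the import path assumes these functions remain in `vllm_omni.entrypoints.utils`:

```python
from vllm_omni.entrypoints.utils import (
    load_stage_configs_from_model,
    resolve_model_config_path,
    resolve_model_type,
)

model = "stabilityai/stable-audio-open-1.0"  # placeholder model id

# Step 1: read model_type from the model's config (raises ValueError if it
# cannot be determined).
model_type = resolve_model_type(model)  # e.g. "StableAudioPipeline"

# Step 2: map the model type to a stage-config YAML; may be None if no file exists.
config_path = resolve_model_config_path(model_type)

# Step 3: load stage configs from that path; returns [] when config_path is None,
# in which case omni.py falls back to a default diffusion stage config.
stage_configs = load_stage_configs_from_model(config_path=config_path)
```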