
Add online serving to Stable Audio Diffusion TTS#1255

Open
ekagra-ranjan wants to merge 6 commits into vllm-project:main from ekagra-ranjan:er-stable-audio-online

Conversation

@ekagra-ranjan

Purpose

Add Stable Audio to online serving: #331 (comment)
As of now, only Qwen3 TTS is supported for online serving.
This PR adds online serving support for pure diffusion TTS models like Stable Audio.

1. Added Stable Audio-specific parameters to OpenAICreateSpeechRequest

Extends the protocol with Stable Audio-specific params (see the protocol sketch after this list).
file: vllm_omni/entrypoints/openai/protocol/audio.py

2. Serving Logic

file: vllm_omni/entrypoints/openai/serving_speech.py and vllm_omni/entrypoints/openai/api_server.py

  1. Sets the right parameters and a sample rate of 44100 Hz.
  2. Adds .for_diffusion() to OmniOpenAIServingSpeech so future diffusion models can continue using it.
  3. _DiffusionServingModels now mocks missing attributes from OpenAIServingModels (e.g. input_processor, model_config, renderer), which is needed because OmniOpenAIServingSpeech inherits from OpenAIServing. The mock means _DiffusionServingModels does not need updating every time a vLLM upgrade adds new attributes to OpenAIServingModels (renderer, for example, was newly added in vLLM 0.15).

3. Documentation and Examples

Created a complete example suite:

  • examples/online_serving/stable_audio/README.md - Full documentation
  • examples/online_serving/stable_audio/curl_examples.sh - Shell script examples
  • examples/online_serving/stable_audio/stable_audio_client.py - Python client
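
The protocol change in step 1 can be pictured roughly as follows. This is a hedged sketch, not the actual contents of vllm_omni/entrypoints/openai/protocol/audio.py; the field names and defaults are assumptions based on the parameters discussed in this PR (negative_prompt, audio_length, seed, guidance_scale, num_inference_steps).

from typing import Optional

from pydantic import BaseModel


class OpenAICreateSpeechRequest(BaseModel):
    # Existing OpenAI-compatible fields (abridged for the sketch).
    model: str
    input: str
    response_format: str = "wav"

    # Stable Audio (diffusion) specific fields; names and defaults are
    # assumptions based on this PR's discussion, not the merged schema.
    negative_prompt: Optional[str] = None
    audio_length: Optional[float] = None  # requested duration in seconds
    seed: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None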

Test Plan

Start the server and run curl_examples.sh.
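
For reference, a minimal Python equivalent of the curl examples might look like the sketch below; the endpoint path, model name, and request fields are assumptions based on this PR's discussion, not verified against the merged example scripts.

import requests

# Hypothetical request mirroring curl_examples.sh; field names follow the
# Stable Audio parameters discussed in this PR.
resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "stabilityai/stable-audio-open-1.0",  # assumed model id
        "input": "a dog barking in a park",
        "audio_length": 5.0,  # seconds
        "seed": 42,
        "guidance_scale": 7.0,
        "num_inference_steps": 100,
    },
    timeout=300,
)
resp.raise_for_status()
with open("dog_5s.wav", "wb") as f:
    f.write(resp.content)  # expected to be a 44100 Hz WAV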

Test Result

dog_5s.wav
ocean.wav
thunder_rain.wav



Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ea38d5aeb8


request_id = f"speech-{random_uuid()}"

try:
    sampling_params_list = self.engine_client.default_sampling_params_list


P1: Clone default sampling params before mutating request fields

create_speech now binds sampling_params_list directly to engine_client.default_sampling_params_list, then mutates entries for Stable Audio requests (generator, guidance_scale, num_inference_steps, extra_args). Because these defaults are shared across requests, one request’s overrides can leak into later requests (or race under concurrency), producing wrong seeds/durations/steps for unrelated calls; this should use per-request clones before applying overrides.
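
A minimal sketch of the per-request cloning this review asks for is shown below; the attribute names (generator, guidance_scale, num_inference_steps) follow the review comment and are assumptions about the actual sampling-params objects.

import copy


def clone_sampling_params(engine_client, request):
    """Clone the shared defaults per request before applying Stable Audio
    overrides, so mutations cannot leak across concurrent requests."""
    sampling_params_list = [
        copy.deepcopy(p) for p in engine_client.default_sampling_params_list
    ]
    for params in sampling_params_list:
        if request.seed is not None:
            params.generator = request.seed  # assumed attribute name
        if request.guidance_scale is not None:
            params.guidance_scale = request.guidance_scale
        if request.num_inference_steps is not None:
            params.num_inference_steps = request.num_inference_steps
    return sampling_params_list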



def _is_stable_audio_model(self) -> bool:
    """Check if the current model is a Stable Audio model."""
    return "stabilityai/stable-audio-open" in self.model_name.lower()


P2: Detect Stable Audio mode without a hardcoded model-name substring

_is_stable_audio_model only matches "stabilityai/stable-audio-open" in self.model_name, so when the server uses --served-model-name aliases (or local/custom model identifiers), Stable Audio requests fall through to the generic branch and silently ignore Stable-specific fields like negative_prompt, audio_length, guidance_scale, and seed. This check should rely on model/stage capability instead of a fixed name substring.


@ekagra-ranjan
Author

I just saw this comment. Let me know which examples I should delete in this PR.

@linyueqian
Contributor

I would not call it TTS, since it is an audio generation model, not capable of generating speech.

@linyueqian
Contributor

Tested locally with a local checkpoint path. The Stable Audio-specific params (audio_length, seed, num_inference_steps, etc.) are all silently ignored.

_is_stable_audio_model() hardcodes "stabilityai/stable-audio-open", so it fails for local paths or --served-model-name aliases, and the request falls through to the generic branch. I got 87 s of audio at 24000 Hz instead of the requested 5 s at 44100 Hz.

Should detect model type from config/architecture instead of a name substring. Also, sampling_params_list is a shared reference that gets mutated in place, which will leak state across concurrent requests.

Separately, since Stable Audio is an audio generation model (not speech/TTS), should we serve it under a different endpoint like /v1/audio/generate instead of /v1/audio/speech? cc @hsliuustc0106 @Gaohan123 thoughts?

Comment on lines +246 to +249
elif self._is_stable_audio_model():
    # Handle Stable Audio models
    # Stable Audio uses diffusion, needs different parameters
    default_sr = 44100  # Default sample rate for Stable Audio
Author


I am not 100% sure how to merge this block with is_tts_model(). As of now, the is_tts_model() block is very Qwen3-specific with its prompt template and "additional_information", so I think there would be some model-specific if-else. I don't know whether there is an existing standardization across the parameters; could that be done later, once standardization happens in vllm-omni?

Contributor


I'd suggest moving the Stable Audio logic to a separate code path for now, like a diffusion-specific branch, rather than mixing it into the TTS flow. They're fundamentally different (autoregressive TTS vs diffusion gen) and trying to unify them now will be forced. We can revisit standardization later when we have a clearer picture.
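
For illustration, the suggested split could be as simple as the sketch below; the helper names (_is_diffusion_model, _create_speech_diffusion, _create_speech_tts) are placeholders, not the actual vllm-omni API.

# Hedged sketch of the suggested split into two code paths; all helper
# names are hypothetical.
class OmniOpenAIServingSpeech:  # abridged
    async def create_speech(self, request):
        if self._is_diffusion_model():  # hypothetical capability check
            # Diffusion generators (e.g. Stable Audio): diffusion sampling
            # params, 44100 Hz output, no prompt template.
            return await self._create_speech_diffusion(request)
        # Autoregressive TTS (e.g. Qwen3 TTS): prompt template plus
        # "additional_information"; existing flow unchanged.
        return await self._create_speech_tts(request)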

@ekagra-ranjan
Author

ekagra-ranjan commented Feb 7, 2026

Should detect model type from config/architecture instead of a name substring.

Got it - that makes sense, given what you are observing.

Also, sampling_params_list is a shared reference that gets mutated in place, which will leak state across concurrent requests.

What is the path forward, then, if we cannot change the default sampling params? Okay, I see what you mean: I need to create a new object instead of changing the default object in place.

…_type code across stage config loading. Avoid inplace change in default sampling arg

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
tokenizer = kwargs.get("tokenizer", None)

base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None
self.model_type = resolve_model_type(model)
Author


I added this to get an identifier that can be relied on when a local path to the model is used. After adding it, I realized that resolve_model_config_path() and load_stage_configs_from_model() share some operations, so I refactored them to reuse the intermediate variables.

@linyueqian - Please let me know if there is a better way to use an existing identifier, in case I missed it.

Contributor


For Qwen3-Omni and Qwen3-TTS, we get model type through engine_client.model_config.hf_config which vLLM populates from config.json at init time, so it works with local paths out of the box. Could you check if the same approach works here instead of adding a separate resolution step?
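
A hedged sketch of that config-based check is below; the attribute names on model_config/hf_config and the "StableAudio" architecture string are assumptions, not verified against the actual configs.

def is_stable_audio_model(model_config) -> bool:
    """Detect Stable Audio from config.json architectures (as exposed via
    engine_client.model_config.hf_config) rather than a name substring, so
    local paths and --served-model-name aliases still work."""
    hf_config = getattr(model_config, "hf_config", None)
    architectures = getattr(hf_config, "architectures", None) or []
    return any("StableAudio" in arch for arch in architectures)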

prompt["negative_prompt"] = request.negative_prompt

# Build sampling params for diffusion
sampling_params_list = [OmniDiffusionSamplingParams(num_outputs_per_prompt=1)]
Author


@linyueqian - Let me know if this resolves your comment.

Contributor


Yes this is better, thanks!
