Skip to content

Commit 92f7208

Browse files
authored
[Bugfix] CosyVoice3: wrap ref_text in instruction template (#4644) (#4756)
1 parent 05a86ed commit 92f7208

3 files changed

Lines changed: 19 additions & 23 deletions

File tree

tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ def test_text2flow_token_only_strips_reference_speech_prefix_from_cumulative_ids
9191
assert outputs[0]["additional_information"]["ids"]["prompt"] == [10, 11]
9292

9393

94-
def test_text2flow_token_only_marks_prompt_trim_for_stop_token_completion():
94+
def test_text2flow_token_only_does_not_mark_prompt_trim():
95+
# The talker prompt is wrapped with the CosyVoice3 instruction template in
96+
# _build_cosyvoice3_prompt, so the talker emits target-only speech and no
97+
# prompt-trim offset is required; the flow stage trims prompt_feat itself
98+
# (issue #4644). Confirm no talker_prefill_offset is set.
9599
source_outputs = [
96100
_source_output(
97101
"req-stop",
@@ -104,7 +108,8 @@ def test_text2flow_token_only_marks_prompt_trim_for_stop_token_completion():
104108
outputs = text2flow_token_only(source_outputs=source_outputs, prompt=None)
105109

106110
assert outputs[0]["prompt_token_ids"] == [1, 2, 6562]
107-
assert outputs[0]["additional_information"]["meta"]["talker_prefill_offset"] == 2
111+
meta = outputs[0]["additional_information"].get("meta") or {}
112+
assert "talker_prefill_offset" not in meta
108113

109114

110115
def test_text2flow_full_payload_does_not_send_codec_ids():

vllm_omni/entrypoints/openai/serving_speech.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@
8282
_QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"}
8383
_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"}
8484
_COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"}
85+
# CosyVoice3 talker expects its reference transcript wrapped in the model
86+
# instruction template; without the delimiter the talker re-speaks the
87+
# reference (issue #4644). Matches the offline example/test and upstream demo.
88+
_COSYVOICE3_PROMPT_DELIMITER = "<|endofprompt|>"
89+
_COSYVOICE3_PROMPT_PREFIX = f"You are a helpful assistant.{_COSYVOICE3_PROMPT_DELIMITER}"
8590
_OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"}
8691
_COVO_AUDIO_MODEL_STAGES = {"fused_thinker_talker"}
8792
_VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"}
@@ -3208,8 +3213,14 @@ async def _build_cosyvoice3_prompt(
32083213
wav_samples, sr = await self._resolve_ref_audio(request.ref_audio)
32093214
audio_data = (np.asarray(wav_samples, dtype=np.float32), sr)
32103215

3216+
# Wrap the reference transcript in the CosyVoice3 instruction template
3217+
# so the talker emits target-only speech (see _COSYVOICE3_PROMPT_PREFIX).
3218+
# Skip wrapping if the caller's ref_text already contains the delimiter.
3219+
ref_text = request.ref_text or ""
3220+
if _COSYVOICE3_PROMPT_DELIMITER not in ref_text:
3221+
ref_text = f"{_COSYVOICE3_PROMPT_PREFIX}{ref_text}"
32113222
mm_kwargs: dict[str, Any] = {
3212-
"prompt_text": request.ref_text,
3223+
"prompt_text": ref_text,
32133224
"sample_rate": sr,
32143225
}
32153226
# Pass voice metadata for caching in the processor

vllm_omni/model_executor/stage_input_processors/cosyvoice3.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
logger = init_logger(__name__)
2323

24-
_COSYVOICE3_SPEECH_TOKEN_SIZE = 6561
25-
2624

2725
def _build_prompt_embed_struct(prompt_payload: dict[str, Any]) -> EmbeddingsStruct | None:
2826
"""Wrap prompt_payload's flat speech_token/speech_feat/embedding tensors into EmbeddingsStruct."""
@@ -85,20 +83,6 @@ def _prompt_speech_token_ids(multi_modal_data: dict[str, Any]) -> list[int]:
8583
return _to_token_id_list(speech_token)
8684

8785

88-
def _has_speech_stop_token(output_ids: list[Any]) -> bool:
89-
return any(token_id >= _COSYVOICE3_SPEECH_TOKEN_SIZE for token_id in _to_token_id_list(output_ids))
90-
91-
92-
def _set_non_stream_prompt_trim(additional_info: dict[str, Any], prompt_speech_len: int) -> None:
93-
if prompt_speech_len <= 0:
94-
return
95-
meta = additional_info.get("meta")
96-
if not isinstance(meta, dict):
97-
meta = {}
98-
additional_info["meta"] = meta
99-
meta["talker_prefill_offset"] = prompt_speech_len
100-
101-
10286
def _to_cpu_tensor(x: Any) -> torch.Tensor | None:
10387
if isinstance(x, list):
10488
if not x:
@@ -154,8 +138,6 @@ def text2flow(
154138
output_ids = _strip_prompt_prefix(raw_output_ids, prefix_ids)
155139
output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids)
156140
additional_info = dict(multi_modal_data)
157-
if _has_speech_stop_token(raw_output_ids):
158-
_set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids))
159141
additional_info.setdefault("ids", {})["prompt"] = prefix_ids
160142
engine_inputs.append(OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=additional_info))
161143
return engine_inputs
@@ -389,8 +371,6 @@ def text2flow_token_only(
389371
prompt_speech_ids = _prompt_speech_token_ids(multi_modal_data)
390372
output_ids = _strip_prompt_prefix(output_ids, prompt_speech_ids)
391373
additional_info: dict[str, Any] = dict(multi_modal_data)
392-
if _has_speech_stop_token(raw_output_ids):
393-
_set_non_stream_prompt_trim(additional_info, len(prompt_speech_ids))
394374
additional_info.setdefault("ids", {})["prompt"] = prefix_ids
395375
engine_inputs.append(
396376
OmniTokensPrompt(

0 commit comments

Comments
 (0)