
Commit d483ebb

Merge branch 'main' into running-eval-suite-skill

2 parents: 71e8a6e + ded5bba

14 files changed: 314 additions & 146 deletions


.claude/skills/tune-ci-thresholds/models/s2-pro-v1/config.yaml

Lines changed: 2 additions & 8 deletions
@@ -10,14 +10,8 @@
 #   live in that same per-variant dir.
 # - Concurrency: CI uses --concurrency 8 (the only fully tuned conc); the
 #   _VC_*_P95 dicts include {1,2,4,8,16}. Discover reads conc=8 row.
-# - Default venv path: /github/home/omni-s2pro to mirror the real CI venv
-#   (test-s2pro-ci-v1.yaml). Critically, omni-s2pro must NOT have
-#   `openai-whisper` / `whisper-normalizer` installed — CI's omni-s2pro
-#   doesn't, which makes `_get_en_normalizer()` fall back to the
-#   punctuation-strip path. omni-qwen3 (used by qwen3-omni-v1 stages) DOES
-#   have openai-whisper installed and uses the real EnglishTextNormalizer,
-#   so the two venvs produce different WER numbers on identical audio.
-#   Keep them separate; do not unify.
+# - Venv: this host shares the omni-qwen3 venv across models; see
+#   default_venv_python below.
 name: s2-pro-v1
 description: "FishAudio S2-Pro voice-clone TTS (v1 pipeline)"
 hf_model_id: "fishaudio/s2-pro"

benchmarks/tasks/tts.py

Lines changed: 9 additions & 59 deletions
@@ -19,13 +19,11 @@
 import time
 import wave
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Protocol

 import aiohttp
 import soundfile as sf
 import torch
-import transformers
 from jiwer import process_words
 from tqdm import tqdm

@@ -80,57 +78,17 @@ class SampleOutput:

 @functools.lru_cache(maxsize=1)
 def _get_en_normalizer():
-    """Lazy-load the English text normalizer.
-
-    Tries whisper_normalizer (standalone pip package) first, then openai-whisper,
-    then the transformers built-in normalizer.
-
-    note (Chenyang): The three fallbacks exist because our deployments don't always
-    have whisper_normalizer installed, whisper's own normalizer lives under a
-    different path depending on the release, and on minimal CI images we rely on
-    the transformers copy bundled with the library. Keeping all three paths lets
-    the WER numbers stay stable across environments (the official seed-tts-eval
-    reference uses whisper_normalizer, so we prefer it when available).
-    """
-    try:
-        from whisper_normalizer.english import EnglishTextNormalizer
-
-        normalizer = EnglishTextNormalizer()
-        logger.info("Using whisper_normalizer.english.EnglishTextNormalizer")
-        return normalizer
-    except ImportError:
-        logger.debug("whisper_normalizer.english.EnglishTextNormalizer failed")
-
+    """Lazy-load the required English WER normalizer from openai-whisper."""
     try:
         from whisper.normalizers import EnglishTextNormalizer
+    except ImportError as exc:
+        raise RuntimeError(
+            "English WER requires openai-whisper "
+            "(whisper.normalizers.EnglishTextNormalizer). "
+            "Install pinned deps with `uv pip install -e .`."
+        ) from exc

-        normalizer = EnglishTextNormalizer()
-        logger.info("Using whisper.normalizers.EnglishTextNormalizer")
-        return normalizer
-    except ImportError:
-        logger.debug("whisper.normalizers.EnglishTextNormalizer failed")
-
-    try:
-        from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
-
-        json_path = (
-            Path(transformers.__file__).parent / "models" / "whisper" / "english.json"
-        )
-        with open(json_path) as f:
-            english_spelling_mapping = json.load(f)
-
-        normalizer = EnglishTextNormalizer(english_spelling_mapping)
-        logger.info(
-            "Using transformers.models.whisper.english_normalizer.EnglishTextNormalizer"
-        )
-        return normalizer
-    except (ImportError, FileNotFoundError) as exc:
-        logger.debug(f"transformers EnglishTextNormalizer failed: {exc}")
-
-    logger.warning(
-        "EnglishTextNormalizer not found; falling back to punctuation-strip normalizer."
-    )
-    return None
+    return EnglishTextNormalizer()


 def normalize_text(text: str, lang: str) -> str:
@@ -147,15 +105,7 @@ def normalize_text(text: str, lang: str) -> str:
         return text

     normalizer = _get_en_normalizer()
-    if normalizer is not None:
-        return normalizer(text)
-
-    for ch in string.punctuation:
-        if ch == "'":
-            continue
-        text = text.replace(ch, "")
-    text = text.replace("  ", " ").strip().lower()
-    return text
+    return normalizer(text)


 def load_asr_model(lang: str, device: str, generation_mode: str | None = None):
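
The net effect of the tts.py change: English WER now has exactly one normalization path, so the metric can no longer drift between venvs. A minimal sketch of the comparison this enables, assuming openai-whisper is installed per the pin below (the sample strings are illustrative, not from the eval set):

    from jiwer import process_words
    from whisper.normalizers import EnglishTextNormalizer

    normalizer = EnglishTextNormalizer()  # now a hard requirement, no fallback

    # With both sides normalized identically, purely orthographic differences
    # (casing, punctuation, abbreviations) stop counting as word errors.
    ref = normalizer("Mr. Smith lives on Fifth Avenue.")
    hyp = normalizer("mister smith lives on fifth avenue")

    out = process_words(ref, hyp)  # jiwer returns a WordOutput with .wer
    print(out.wer)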

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ dependencies = [
     "pytest-asyncio>=0.21.0",
     "jiwer",
     "scipy>=1.10.0",
-    "openai-whisper",
+    "openai-whisper==20250625",
     # S2-Pro
     "tiktoken",
     "hydra-core",

sglang_omni_v1/models/fishaudio_s2_pro/bootstrap.py

Lines changed: 2 additions & 2 deletions
@@ -106,7 +106,7 @@ def bootstrap_text_model_for_decode(
     audio_decoder: torch.nn.Module,
     semantic_begin_id: int,
     semantic_end_id: int,
-    im_end_id: int,
+    im_end_token_id: int,
     max_batch_size: int,
     num_codebooks: int,
     codebook_size: int,
@@ -119,6 +119,6 @@
         codebook_size=codebook_size,
         semantic_begin_id=semantic_begin_id,
         semantic_end_id=semantic_end_id,
-        im_end_id=im_end_id,
+        im_end_token_id=im_end_token_id,
         max_batch_size=max_batch_size,
     )

sglang_omni_v1/models/fishaudio_s2_pro/fish_scheduler.py

Lines changed: 21 additions & 4 deletions
@@ -236,7 +236,7 @@ def __init__(
         self, tree_cache: Any, im_end_token_id: int, max_new_tokens: int = 2048
     ):
         self.tree_cache = tree_cache
-        self._im_end_id = int(im_end_token_id)
+        self._im_end_token_id = int(im_end_token_id)
         self._max_new_tokens = int(max_new_tokens)

     def update_request(
@@ -250,7 +250,12 @@ def update_request(
             return

         if output_token_id is not None:
-            req.output_ids.append(int(output_token_id))
+            semantic_token = int(output_token_id)
+            req.output_ids.append(semantic_token)
+            # Skip caching the terminal slow-AR EOS regardless of req.finished()
+            # semantics: it is not an audio timestep and has no KV to preserve.
+            if semantic_token == self._im_end_token_id:
+                return
         if not req.finished() and req.decode_batch_idx == 0:
             self.tree_cache.cache_unfinished_req(req)

@@ -265,7 +270,7 @@ def is_finished(
         if semantic_token is None and data.previous_semantic_tokens:
             semantic_token = int(data.previous_semantic_tokens[-1])

-        if semantic_token == self._im_end_id:
+        if semantic_token == self._im_end_token_id:
             return True

         max_tok = data.max_new_tokens or self._max_new_tokens
@@ -418,8 +423,20 @@ def emit_finished(self, finished: list[SchedulerRequest]) -> None:
         for request in finished:
             data = request.data
             data.output_ids = list(data.req.output_ids)
-            result = self._result_adapter(data)
             t_submit = self._submit_times.pop(request.request_id, None)
+            if not data.output_codes:
+                self.outbox.put(
+                    OutgoingMessage(
+                        request_id=request.request_id,
+                        type="error",
+                        data=ValueError(
+                            f"Request {request.request_id}: "
+                            "S2-Pro generated no audio codec tokens"
+                        ),
+                    )
+                )
+                continue
+            result = self._result_adapter(data)
             if t_submit is not None and isinstance(result.data, dict):
                 result.data["engine_time_s"] = time.perf_counter() - t_submit
             self.outbox.put(
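
The update_request change is easiest to verify in isolation: the terminal EOS still lands in output_ids (so is_finished can observe it) but never reaches the tree cache. A stand-alone sketch with stub types (StubCache and StubReq are hypothetical stand-ins for the real sglang classes; the EOS id is arbitrary):

    class StubCache:
        def __init__(self):
            self.cached = []

        def cache_unfinished_req(self, req):
            self.cached.append(req)


    class StubReq:
        def __init__(self):
            self.output_ids = []
            self.decode_batch_idx = 0

        def finished(self):
            return False


    IM_END = 151645  # hypothetical im_end token id


    def update(cache, req, token):
        semantic_token = int(token)
        req.output_ids.append(semantic_token)
        if semantic_token == IM_END:  # terminal EOS: record it, skip KV caching
            return
        if not req.finished() and req.decode_batch_idx == 0:
            cache.cache_unfinished_req(req)


    cache, req = StubCache(), StubReq()
    update(cache, req, 1234)    # ordinary semantic token: cached
    update(cache, req, IM_END)  # EOS: appended to output_ids, not cached
    assert req.output_ids == [1234, IM_END] and len(cache.cached) == 1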

sglang_omni_v1/models/fishaudio_s2_pro/model_runner.py

Lines changed: 38 additions & 16 deletions
@@ -10,13 +10,44 @@
 from sglang_omni_v1.model_runner.base import ModelRunner


+def collect_s2pro_step_outputs(
+    result: Any,
+    requests: list,
+    *,
+    output_codes: torch.Tensor,
+    output_semantic_ids: torch.Tensor,
+    im_end_token_id: int,
+) -> None:
+    batch_size = len(requests)
+    if batch_size == 0:
+        return
+
+    result.next_token_ids = output_semantic_ids[:batch_size].clone()
+    semantic_tokens = output_semantic_ids[:batch_size].tolist()
+
+    for row_idx, sched_req in enumerate(requests):
+        data = sched_req.data
+        if data.req.is_chunked > 0:
+            continue
+
+        semantic_token = semantic_tokens[row_idx]
+        if semantic_token == im_end_token_id:
+            continue
+
+        codes = output_codes[row_idx].unsqueeze(-1).clone()
+        data.last_codebook_values = codes[1:, 0].clone()
+        data.previous_semantic_tokens.append(semantic_token)
+        data.output_codes.append(codes)
+
+
 class FishS2ProModelRunner(ModelRunner):
     """Fish TTS runner with unified forward-owned decode and persistent buffers."""

     def __init__(self, tp_worker: Any, output_processor: Any):
         super().__init__(tp_worker, output_processor)
         self._semantic_begin_id = int(self.model._semantic_begin_id)
         self._semantic_end_id = int(self.model._semantic_end_id)
+        self._im_end_token_id = int(self.model._im_end_token_id)

     def prepare_prefill(self, forward_batch, schedule_batch, requests):
         del schedule_batch
@@ -117,19 +148,10 @@ def _build_prefill_input_embeds(
         return text_embeds

     def _collect_step_outputs(self, result: Any, requests: list) -> None:
-        batch_size = len(requests)
-        if batch_size == 0:
-            return
-
-        result.next_token_ids = self.model._output_semantic_ids[:batch_size].clone()
-
-        for row_idx, sched_req in enumerate(requests):
-            data = sched_req.data
-            req = data.req
-            if req.is_chunked > 0:
-                continue
-
-            codes = self.model._output_codes[row_idx].unsqueeze(-1).clone()
-            data.last_codebook_values = codes[1:, 0].clone()
-            data.previous_semantic_tokens.append(int(codes[0, -1].item()))
-            data.output_codes.append(codes)
+        collect_s2pro_step_outputs(
+            result,
+            requests,
+            output_codes=self.model._output_codes,
+            output_semantic_ids=self.model._output_semantic_ids,
+            im_end_token_id=self._im_end_token_id,
+        )
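
Hoisting the loop into module-level collect_s2pro_step_outputs makes it exercisable without a live runner. A hedged sketch of driving it with toy tensors, assuming the package is importable (the SimpleNamespace request/result objects are stand-ins; only the [batch, 1 + num_codebooks] row layout, semantic token first, follows the code above):

    import types

    import torch

    from sglang_omni_v1.models.fishaudio_s2_pro.model_runner import (
        collect_s2pro_step_outputs,
    )

    IM_END = 7  # toy EOS id


    def make_req():
        data = types.SimpleNamespace(
            req=types.SimpleNamespace(is_chunked=0),
            last_codebook_values=None,
            previous_semantic_tokens=[],
            output_codes=[],
        )
        return types.SimpleNamespace(data=data)


    result = types.SimpleNamespace(next_token_ids=None)
    reqs = [make_req(), make_req()]

    # Row 0 emits a semantic token plus two codebook values; row 1 emits EOS.
    output_codes = torch.tensor([[3, 10, 11], [IM_END, 0, 0]])
    semantic_ids = torch.tensor([3, IM_END])

    collect_s2pro_step_outputs(
        result,
        reqs,
        output_codes=output_codes,
        output_semantic_ids=semantic_ids,
        im_end_token_id=IM_END,
    )

    assert reqs[0].data.previous_semantic_tokens == [3]
    assert reqs[1].data.output_codes == []  # EOS row adds no audio step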

sglang_omni_v1/models/fishaudio_s2_pro/request_builders.py

Lines changed: 6 additions & 5 deletions
@@ -101,11 +101,12 @@ def build_sglang_tts_request(


 def apply_tts_result(state: S2ProState, result: S2ProSGLangRequestData) -> None:
-    if result.output_codes:
-        state.output_codes = torch.cat(result.output_codes, dim=1)
-        state.completion_tokens = state.output_codes.shape[1]
-    else:
-        state.output_codes = None
+    assert result.output_codes, (
+        "apply_tts_result expects non-empty output_codes; "
+        "FishScheduler.emit_finished must filter immediate-EOS cases"
+    )
+    state.output_codes = torch.cat(result.output_codes, dim=1)
+    state.completion_tokens = state.output_codes.shape[1]
     state.prompt_tokens = len(result.input_ids) if result.input_ids is not None else 0
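
What apply_tts_result now assumes, stated as shapes: each finished request carries at least one per-step code column of shape [1 + num_codebooks, 1], and concatenating along dim=1 yields the [1 + num_codebooks, T] matrix whose width T is the completion-token count. A small shape sketch (K=2 codebooks and T=3 steps are arbitrary):

    import torch

    # Three per-step columns of shape [1 + K, 1], as appended by the runner.
    steps = [torch.zeros(3, 1, dtype=torch.long) for _ in range(3)]

    codes = torch.cat(steps, dim=1)
    assert codes.shape == (3, 3)        # [1 + K, T]
    completion_tokens = codes.shape[1]  # T == 3 audio timesteps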

sglang_omni_v1/models/fishaudio_s2_pro/sglang_model.py

Lines changed: 3 additions & 2 deletions
@@ -242,7 +242,7 @@ def setup_vq_decode(
         codebook_size: int,
         semantic_begin_id: int,
         semantic_end_id: int,
-        im_end_id: int,
+        im_end_token_id: int,
         max_batch_size: int,
     ) -> None:
         """Attach audio decoder and allocate persistent GPU buffers."""
@@ -254,6 +254,7 @@
         self._num_codebooks = num_codebooks
         self._semantic_begin_id = semantic_begin_id
         self._semantic_end_id = semantic_end_id
+        self._im_end_token_id = int(im_end_token_id)

         # Shared codebook embedding from audio decoder (for VQ input combination)
         self._vq_codebook_embeddings = audio_decoder.codebook_embeddings
@@ -271,7 +272,7 @@
             (self.vocab_size,), -float("inf"), device=device, dtype=torch.bfloat16
         )
         bias[semantic_begin_id : semantic_end_id + 1] = 0.0
-        bias[im_end_id] = 0.0
+        bias[im_end_token_id] = 0.0
         self._semantic_bias = bias

         # Output buffers: written by _decode_codebooks, read by ModelRunner
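
The _semantic_bias is an additive logit mask: -inf everywhere except the contiguous semantic-id range and the single im_end id, so the decode step can only emit audio tokens or stop. A minimal sketch of the construction and its effect (vocab size and ids are arbitrary; the real buffers live on GPU in bfloat16):

    import torch

    vocab_size, sem_begin, sem_end, im_end = 32, 10, 19, 5

    bias = torch.full((vocab_size,), -float("inf"))
    bias[sem_begin : sem_end + 1] = 0.0  # allow semantic codebook ids
    bias[im_end] = 0.0                   # ...and the stop token

    logits = torch.randn(vocab_size) + bias  # masked ids stay at -inf
    next_id = int(logits.argmax())
    assert next_id == im_end or sem_begin <= next_id <= sem_end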

sglang_omni_v1/models/fishaudio_s2_pro/stages.py

Lines changed: 1 addition & 1 deletion
@@ -248,7 +248,7 @@ def create_sglang_tts_engine_executor(
         audio_decoder=audio_decoder,
         semantic_begin_id=adapter.semantic_begin_id,
         semantic_end_id=adapter.semantic_end_id,
-        im_end_id=adapter.eos_token_ids[0],
+        im_end_token_id=adapter.eos_token_ids[0],
         max_batch_size=server_args.max_running_requests,
         num_codebooks=num_codebooks,
         codebook_size=codebook_size,

sglang_omni_v1/models/qwen3_omni/request_builders.py

Lines changed: 1 addition & 4 deletions
@@ -18,6 +18,7 @@
 from sglang_omni_v1.proto import StagePayload
 from sglang_omni_v1.scheduling.messages import OutgoingMessage
 from sglang_omni_v1.scheduling.sglang_backend import SGLangARRequestData
+from sglang_omni_v1.scheduling.types import ARRequestData

 IMAGE_STAGE = "image_encoder"
 AUDIO_STAGE = "audio_encoder"
@@ -32,10 +33,6 @@ class EncoderRequestData:
     skip_result: dict[str, Any] | None = None


-class ARRequestData:
-    """AR request data — base for SGLangARRequestData."""
-
-
 def build_encoder_request(
     state: PipelineState, *, stage_name: str
 ) -> EncoderRequestData:
