Skip to content

Commit ded5bba

Browse files
[S2-Pro]: Fix S2-Pro terminal EOS audio frame (#377)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
1 parent b1697ce commit ded5bba

7 files changed

Lines changed: 278 additions & 30 deletions

File tree

sglang_omni_v1/models/fishaudio_s2_pro/bootstrap.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def bootstrap_text_model_for_decode(
106106
audio_decoder: torch.nn.Module,
107107
semantic_begin_id: int,
108108
semantic_end_id: int,
109-
im_end_id: int,
109+
im_end_token_id: int,
110110
max_batch_size: int,
111111
num_codebooks: int,
112112
codebook_size: int,
@@ -119,6 +119,6 @@ def bootstrap_text_model_for_decode(
119119
codebook_size=codebook_size,
120120
semantic_begin_id=semantic_begin_id,
121121
semantic_end_id=semantic_end_id,
122-
im_end_id=im_end_id,
122+
im_end_token_id=im_end_token_id,
123123
max_batch_size=max_batch_size,
124124
)

sglang_omni_v1/models/fishaudio_s2_pro/fish_scheduler.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,7 @@ def __init__(
236236
self, tree_cache: Any, im_end_token_id: int, max_new_tokens: int = 2048
237237
):
238238
self.tree_cache = tree_cache
239-
self._im_end_id = int(im_end_token_id)
239+
self._im_end_token_id = int(im_end_token_id)
240240
self._max_new_tokens = int(max_new_tokens)
241241

242242
def update_request(
@@ -250,7 +250,12 @@ def update_request(
250250
return
251251

252252
if output_token_id is not None:
253-
req.output_ids.append(int(output_token_id))
253+
semantic_token = int(output_token_id)
254+
req.output_ids.append(semantic_token)
255+
# Skip caching the terminal slow-AR EOS regardless of req.finished()
256+
# semantics: it is not an audio timestep and has no KV to preserve.
257+
if semantic_token == self._im_end_token_id:
258+
return
254259
if not req.finished() and req.decode_batch_idx == 0:
255260
self.tree_cache.cache_unfinished_req(req)
256261

@@ -265,7 +270,7 @@ def is_finished(
265270
if semantic_token is None and data.previous_semantic_tokens:
266271
semantic_token = int(data.previous_semantic_tokens[-1])
267272

268-
if semantic_token == self._im_end_id:
273+
if semantic_token == self._im_end_token_id:
269274
return True
270275

271276
max_tok = data.max_new_tokens or self._max_new_tokens
@@ -418,8 +423,20 @@ def emit_finished(self, finished: list[SchedulerRequest]) -> None:
418423
for request in finished:
419424
data = request.data
420425
data.output_ids = list(data.req.output_ids)
421-
result = self._result_adapter(data)
422426
t_submit = self._submit_times.pop(request.request_id, None)
427+
if not data.output_codes:
428+
self.outbox.put(
429+
OutgoingMessage(
430+
request_id=request.request_id,
431+
type="error",
432+
data=ValueError(
433+
f"Request {request.request_id}: "
434+
"S2-Pro generated no audio codec tokens"
435+
),
436+
)
437+
)
438+
continue
439+
result = self._result_adapter(data)
423440
if t_submit is not None and isinstance(result.data, dict):
424441
result.data["engine_time_s"] = time.perf_counter() - t_submit
425442
self.outbox.put(

sglang_omni_v1/models/fishaudio_s2_pro/model_runner.py

Lines changed: 38 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,44 @@
1010
from sglang_omni_v1.model_runner.base import ModelRunner
1111

1212

13+
def collect_s2pro_step_outputs(
14+
result: Any,
15+
requests: list,
16+
*,
17+
output_codes: torch.Tensor,
18+
output_semantic_ids: torch.Tensor,
19+
im_end_token_id: int,
20+
) -> None:
21+
batch_size = len(requests)
22+
if batch_size == 0:
23+
return
24+
25+
result.next_token_ids = output_semantic_ids[:batch_size].clone()
26+
semantic_tokens = output_semantic_ids[:batch_size].tolist()
27+
28+
for row_idx, sched_req in enumerate(requests):
29+
data = sched_req.data
30+
if data.req.is_chunked > 0:
31+
continue
32+
33+
semantic_token = semantic_tokens[row_idx]
34+
if semantic_token == im_end_token_id:
35+
continue
36+
37+
codes = output_codes[row_idx].unsqueeze(-1).clone()
38+
data.last_codebook_values = codes[1:, 0].clone()
39+
data.previous_semantic_tokens.append(semantic_token)
40+
data.output_codes.append(codes)
41+
42+
1343
class FishS2ProModelRunner(ModelRunner):
1444
"""Fish TTS runner with unified forward-owned decode and persistent buffers."""
1545

1646
def __init__(self, tp_worker: Any, output_processor: Any):
1747
super().__init__(tp_worker, output_processor)
1848
self._semantic_begin_id = int(self.model._semantic_begin_id)
1949
self._semantic_end_id = int(self.model._semantic_end_id)
50+
self._im_end_token_id = int(self.model._im_end_token_id)
2051

2152
def prepare_prefill(self, forward_batch, schedule_batch, requests):
2253
del schedule_batch
@@ -117,19 +148,10 @@ def _build_prefill_input_embeds(
117148
return text_embeds
118149

119150
def _collect_step_outputs(self, result: Any, requests: list) -> None:
120-
batch_size = len(requests)
121-
if batch_size == 0:
122-
return
123-
124-
result.next_token_ids = self.model._output_semantic_ids[:batch_size].clone()
125-
126-
for row_idx, sched_req in enumerate(requests):
127-
data = sched_req.data
128-
req = data.req
129-
if req.is_chunked > 0:
130-
continue
131-
132-
codes = self.model._output_codes[row_idx].unsqueeze(-1).clone()
133-
data.last_codebook_values = codes[1:, 0].clone()
134-
data.previous_semantic_tokens.append(int(codes[0, -1].item()))
135-
data.output_codes.append(codes)
151+
collect_s2pro_step_outputs(
152+
result,
153+
requests,
154+
output_codes=self.model._output_codes,
155+
output_semantic_ids=self.model._output_semantic_ids,
156+
im_end_token_id=self._im_end_token_id,
157+
)

sglang_omni_v1/models/fishaudio_s2_pro/request_builders.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -101,11 +101,12 @@ def build_sglang_tts_request(
101101

102102

103103
def apply_tts_result(state: S2ProState, result: S2ProSGLangRequestData) -> None:
104-
if result.output_codes:
105-
state.output_codes = torch.cat(result.output_codes, dim=1)
106-
state.completion_tokens = state.output_codes.shape[1]
107-
else:
108-
state.output_codes = None
104+
assert result.output_codes, (
105+
"apply_tts_result expects non-empty output_codes; "
106+
"FishScheduler.emit_finished must filter immediate-EOS cases"
107+
)
108+
state.output_codes = torch.cat(result.output_codes, dim=1)
109+
state.completion_tokens = state.output_codes.shape[1]
109110
state.prompt_tokens = len(result.input_ids) if result.input_ids is not None else 0
110111

111112

sglang_omni_v1/models/fishaudio_s2_pro/sglang_model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ def setup_vq_decode(
242242
codebook_size: int,
243243
semantic_begin_id: int,
244244
semantic_end_id: int,
245-
im_end_id: int,
245+
im_end_token_id: int,
246246
max_batch_size: int,
247247
) -> None:
248248
"""Attach audio decoder and allocate persistent GPU buffers."""
@@ -254,6 +254,7 @@ def setup_vq_decode(
254254
self._num_codebooks = num_codebooks
255255
self._semantic_begin_id = semantic_begin_id
256256
self._semantic_end_id = semantic_end_id
257+
self._im_end_token_id = int(im_end_token_id)
257258

258259
# Shared codebook embedding from audio decoder (for VQ input combination)
259260
self._vq_codebook_embeddings = audio_decoder.codebook_embeddings
@@ -271,7 +272,7 @@ def setup_vq_decode(
271272
(self.vocab_size,), -float("inf"), device=device, dtype=torch.bfloat16
272273
)
273274
bias[semantic_begin_id : semantic_end_id + 1] = 0.0
274-
bias[im_end_id] = 0.0
275+
bias[im_end_token_id] = 0.0
275276
self._semantic_bias = bias
276277

277278
# Output buffers: written by _decode_codebooks, read by ModelRunner

sglang_omni_v1/models/fishaudio_s2_pro/stages.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ def create_sglang_tts_engine_executor(
248248
audio_decoder=audio_decoder,
249249
semantic_begin_id=adapter.semantic_begin_id,
250250
semantic_end_id=adapter.semantic_end_id,
251-
im_end_id=adapter.eos_token_ids[0],
251+
im_end_token_id=adapter.eos_token_ids[0],
252252
max_batch_size=server_args.max_running_requests,
253253
num_codebooks=num_codebooks,
254254
codebook_size=codebook_size,

0 commit comments

Comments
 (0)