Skip to content

Commit af02862

Browse files
committed
Fix S2-Pro v1 streaming usage metrics
1 parent c1cd91b commit af02862

4 files changed

Lines changed: 62 additions & 1 deletion

File tree

sglang_omni_v1/models/fishaudio_s2_pro/fish_scheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ def emit_finished(self, finished: list[SchedulerRequest]) -> None:
416416
for request in finished:
417417
data = request.data
418418
data.output_ids = list(data.req.output_ids)
419+
assert request.finish_time is not None
420+
data.stage_payload.data["engine_time_s"] = max(
421+
request.finish_time - request.arrival_time,
422+
1e-6,
423+
)
419424
result = self._result_adapter(data)
420425
self.outbox.put(
421426
OutgoingMessage(

sglang_omni_v1/models/fishaudio_s2_pro/request_builders.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ def apply_tts_result(state: S2ProState, result: S2ProSGLangRequestData) -> None:
109109
state.prompt_tokens = len(result.input_ids) if result.input_ids is not None else 0
110110

111111

112+
def build_tts_usage(state: S2ProState) -> dict[str, Any]:
113+
usage = {
114+
"prompt_tokens": int(state.prompt_tokens),
115+
"completion_tokens": int(state.completion_tokens),
116+
"total_tokens": int(state.prompt_tokens + state.completion_tokens),
117+
}
118+
if state.engine_time_s > 0:
119+
usage["engine_time_s"] = round(float(state.engine_time_s), 6)
120+
return usage
121+
122+
112123
def make_tts_scheduler_adapters(*, tokenizer: Any):
113124
"""Build model-specific StagePayload <-> scheduler adapters for Fish TTS."""
114125

@@ -126,10 +137,13 @@ def result_adapter(data: S2ProSGLangRequestData) -> StagePayload:
126137
payload = data.stage_payload
127138
state = S2ProState.from_dict(payload.data)
128139
apply_tts_result(state, data)
140+
state.engine_time_s = float(payload.data["engine_time_s"])
141+
result_data = state.to_dict()
142+
result_data["usage"] = build_tts_usage(state)
129143
return StagePayload(
130144
request_id=payload.request_id,
131145
request=payload.request,
132-
data=state.to_dict(),
146+
data=result_data,
133147
)
134148

135149
return request_builder, result_adapter

sglang_omni_v1/models/fishaudio_s2_pro/stages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,11 @@ def _store_audio(
291291
state: S2ProState,
292292
audio_np: torch.Tensor,
293293
) -> StagePayload:
294+
usage = payload.data["usage"]
294295
state.audio_samples = audio_np
295296
state.sample_rate = codec.sample_rate
296297
payload = store_state(payload, state)
298+
payload.data["usage"] = usage
297299
payload.data["audio_data"] = audio_np.tolist()
298300
payload.data["sample_rate"] = codec.sample_rate
299301
payload.data["modality"] = "audio"

tests/test_v1_fish_vocoder_batch.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def _payload(request_id: str, code_len: int) -> StagePayload:
3636
)
3737

3838

39+
def _run_vocoder_request(scheduler, payload: StagePayload) -> StagePayload:
40+
thread = threading.Thread(target=scheduler.start, daemon=True)
41+
thread.start()
42+
try:
43+
scheduler.inbox.put(IncomingMessage(payload.request_id, "new_request", payload))
44+
output = scheduler.outbox.get(timeout=2.0)
45+
return output.data
46+
finally:
47+
scheduler.stop()
48+
thread.join(timeout=2.0)
49+
50+
3951
def test_fish_vocoder_uses_simple_scheduler_batch_path(monkeypatch) -> None:
4052
codec = _FakeCodec()
4153
monkeypatch.setattr(stages, "_resolve_checkpoint", lambda model_path: model_path)
@@ -69,3 +81,31 @@ def test_fish_vocoder_uses_simple_scheduler_batch_path(monkeypatch) -> None:
6981
assert len(outputs["req-long"].data["audio_data"]) == 12
7082
assert outputs["req-short"].data["audio_data"] == [1.0] * 8
7183
assert outputs["req-long"].data["audio_data"] == [2.0] * 12
84+
85+
86+
def test_fish_vocoder_preserves_existing_usage(monkeypatch) -> None:
87+
codec = _FakeCodec()
88+
monkeypatch.setattr(stages, "_resolve_checkpoint", lambda model_path: model_path)
89+
monkeypatch.setattr(stages, "_load_codec", lambda checkpoint, device: codec)
90+
91+
payload = _payload("req-usage", 2)
92+
usage = {
93+
"prompt_tokens": 3,
94+
"completion_tokens": 2,
95+
"total_tokens": 5,
96+
"engine_time_s": 0.25,
97+
}
98+
payload.data["usage"] = usage
99+
scheduler = stages.create_vocoder_executor(
100+
"unused",
101+
device="cpu",
102+
max_batch_size=1,
103+
max_batch_wait_ms=1,
104+
)
105+
106+
output = _run_vocoder_request(scheduler, payload)
107+
108+
assert output.data["usage"] == usage
109+
assert output.data["audio_data"] == [1.0] * 8
110+
assert output.data["sample_rate"] == 44100
111+
assert output.data["modality"] == "audio"

0 commit comments

Comments (0)