diff --git a/examples/run_qwen3_omni_speech_server.py b/examples/run_qwen3_omni_speech_server.py index fa830bfb4..413159dd2 100644 --- a/examples/run_qwen3_omni_speech_server.py +++ b/examples/run_qwen3_omni_speech_server.py @@ -60,6 +60,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--relay-backend", type=str, default="shm", choices=["nixl", "shm"] ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=0.7, + help="Static memory fraction for SGLang-backed AR stages.", + ) # Server parser.add_argument("--host", type=str, default="0.0.0.0") @@ -91,6 +97,13 @@ async def main_async(args: argparse.Namespace) -> None: gpu_placement=gpu_placement, ) + server_args_overrides = {"mem_fraction_static": args.mem_fraction_static} + for stage in config.stages: + if stage.name in {"thinker", "talker_ar"}: + stage.executor.args.setdefault("server_args_overrides", {}).update( + server_args_overrides + ) + runner = MultiProcessPipelineRunner(config) logger.info("Starting 9-stage speech pipeline (multiprocess)...") await runner.start(timeout=600) diff --git a/playground/README.md b/playground/README.md index bde2b1a82..e4be0d02d 100644 --- a/playground/README.md +++ b/playground/README.md @@ -6,6 +6,7 @@ This directory contains multiple playground interfaces for SGLang-Omni. |---|---| | `web/` | Full-featured HTML/CSS/JS UI served directly by the sglang-omni server. Supports text, audio, image, video inputs and a built-in file browser. | | `gradio/` | Lightweight Gradio app that connects to a running server via HTTP. Text chat with streaming, model selector, and generation parameter controls. | +| `realtime-ws/` | Standalone websocket realtime app with server-side VAD, text input, microphone streaming, and streamed assistant audio playback. | | `tts/` | S2 Pro TTS Gradio app with shared controls for voice cloning plus separate streaming and non-streaming playback modes. | ## Web Playground @@ -20,6 +21,88 @@ uv pip install -v -e . Then open `http://localhost:8000` in your browser. +## Realtime WebSocket Playground + +Install the project before launching: + +```bash +uv pip install -v -e . +``` + +Launch the backend plus standalone frontend app with one command: + +```bash +./playground/realtime-ws/start.sh [--mock] [realtime-options] [backend-options...] +``` + +Minimal usable commands: + +```bash +# local smoke test +./playground/realtime-ws/start.sh --mock + +# real model +./playground/realtime-ws/start.sh --model-path Qwen/Qwen3-Omni-30B-A3B-Instruct +``` + +In normal backend mode, pass the usual speech server flags such as `--model-path`: + +```bash +./playground/realtime-ws/start.sh \ + --model-path Qwen/Qwen3-Omni-30B-A3B-Instruct +``` + +Then open `http://localhost:7862`. + +For a browser smoke test without loading any model, launch the mock realtime API: + +```bash +./playground/realtime-ws/start.sh --mock +``` + +That path exercises: + +- browser microphone capture over websocket PCM streaming +- server-side VAD turn detection +- automatic response start after speech stop +- streamed assistant audio playback in the browser +- text prompts over the same websocket session + +The mock backend returns canned text plus playback of the captured client audio +(falling back to a synthetic tone when there is no input audio) instead of +calling the inference pipeline. + +### Remote browser over SSH port forwarding + +Because the transport is plain HTTP + WebSocket, standard SSH forwarding is +enough for remote browser testing. + +Example: + +```bash +./playground/realtime-ws/start.sh --mock +``` + +Forward the backend port and the frontend port from the remote machine: + +```bash +ssh -L 8000:localhost:8000 -L 7862:localhost:7862 user@host +``` + +For the full launcher help, run: + +```bash +./playground/realtime-ws/start.sh --help +``` + +The websocket playground: + +- streams microphone PCM to the backend over `/v1/realtime/ws` +- runs server-side VAD to auto-trigger one inference turn per utterance +- supports manual push-to-talk and text prompts in the same session +- streams assistant audio back over the websocket and auto-plays it in the browser +- keeps the frontend separate from the inference API server + ### Custom port ```bash @@ -95,6 +178,7 @@ ssh -L 8000:localhost:8000 -L 7860:localhost:7860 user@host | `/` | Web playground UI (index.html, app.js, styles.css) | | `/v1/chat/completions` | Chat completions (text + audio, streaming) | | `/v1/audio/speech` | Text-to-speech | +| `/v1/realtime/ws` | Realtime websocket session transport | | `/v1/models` | List available models | | `/v1/fs/list` | Browse server filesystem | | `/v1/fs/file` | Download a server file | diff --git a/playground/realtime-ws/app.js b/playground/realtime-ws/app.js new file mode 100644 index 000000000..5893472a6 --- /dev/null +++ b/playground/realtime-ws/app.js @@ -0,0 +1,811 @@ +(function () { + "use strict"; + + const $ = (id) => document.getElementById(id); + const connectBtn = $("connect"); + const disconnectBtn = $("disconnect"); + const pushToTalkBtn = $("push-to-talk"); + const clearLogBtn = $("clear-log"); + const statusEl = $("status"); + const logEl = $("log"); + const conversationEl = $("conversation"); + const instructionsEl = $("instructions"); + const userPromptEl = $("user-prompt"); + const sendTextBtn = $("send-text"); + const audioModeAutoEl = $("audio-mode-auto"); + const audioModePushEl = $("audio-mode-push"); + const audioModeHelpEl = $("audio-mode-help"); + const micLevelFillEl = $("mic-level-fill"); + const micLevelTextEl = $("mic-level-text"); + const remoteLevelFillEl = $("remote-level-fill"); + const remoteLevelTextEl = $("remote-level-text"); + + let ws = null; + let localStream = null; + let sessionId = null; + let inputAudioMode = "vad"; + let pushToTalkActive = false; + let pushToTalkKeyActive = false; + let assistantMessages = new Map(); + + let captureAudioContext = null; + let captureSource = null; + let captureProcessor = null; + let captureSink = null; + let captureSampleRate = 16000; + + let playbackAudioContext = null; + let playbackCursor = 0; + let playbackGeneration = 0; + let playbackSources = new Set(); + let outputSampleRate = 24000; + let remoteLevel = 0; + let remoteLevelRaf = 0; + + function getApiBase() { + if ( + typeof window !== "undefined" && + Object.prototype.hasOwnProperty.call(window, "SGLANG_OMNI_API_BASE") + ) { + return String(window.SGLANG_OMNI_API_BASE || "").trim().replace(/\/$/, ""); + } + return window.location.origin; + } + + function buildWebSocketUrl() { + const base = new URL(getApiBase(), window.location.href); + base.protocol = base.protocol === "https:" ? "wss:" : "ws:"; + base.pathname = "/v1/realtime/ws"; + base.search = ""; + return base.toString(); + } + + function setStatus(text) { + statusEl.textContent = text; + } + + function log(message, payload) { + const stamp = new Date().toLocaleTimeString(); + const line = `[${stamp}] ${message}`; + const body = payload ? `\n${JSON.stringify(payload, null, 2)}` : ""; + logEl.textContent += `${line}${body}\n\n`; + logEl.scrollTop = logEl.scrollHeight; + } + + function scrollConversationToBottom() { + conversationEl.scrollTop = conversationEl.scrollHeight; + } + + function addConversationEmptyState() { + if (conversationEl.querySelector(".conversation-empty")) { + return; + } + const empty = document.createElement("div"); + empty.className = "conversation-empty"; + empty.textContent = "Conversation history will appear here."; + conversationEl.appendChild(empty); + } + + function clearConversation() { + assistantMessages = new Map(); + conversationEl.textContent = ""; + addConversationEmptyState(); + } + + function createConversationMessage(role, text, pending) { + const container = document.createElement("div"); + container.className = `message message-${role}${pending ? " pending" : ""}`; + + const roleEl = document.createElement("div"); + roleEl.className = "message-role"; + roleEl.textContent = role === "user" ? "User" : "Assistant"; + + const contentEl = document.createElement("div"); + contentEl.className = "message-content"; + contentEl.textContent = text; + + container.appendChild(roleEl); + container.appendChild(contentEl); + return { container, contentEl }; + } + + function appendConversationMessage(role, text, pending = false) { + const emptyState = conversationEl.querySelector(".conversation-empty"); + if (emptyState) { + emptyState.remove(); + } + const message = createConversationMessage(role, text, pending); + conversationEl.appendChild(message.container); + scrollConversationToBottom(); + return message; + } + + function ensureAssistantMessage(responseId) { + if (!responseId) { + return null; + } + const existing = assistantMessages.get(responseId); + if (existing) { + return existing; + } + const message = appendConversationMessage("assistant", "", true); + const entry = { + container: message.container, + contentEl: message.contentEl, + rawText: "", + hasVisibleText: false, + hasAudio: false, + }; + assistantMessages.set(responseId, entry); + return entry; + } + + function appendAssistantDelta(responseId, delta) { + if (!responseId || typeof delta !== "string" || delta.length === 0) { + return; + } + const entry = ensureAssistantMessage(responseId); + if (!entry) { + return; + } + entry.rawText += delta; + entry.contentEl.textContent = entry.rawText; + if (entry.rawText.trim().length > 0) { + entry.hasVisibleText = true; + } + scrollConversationToBottom(); + } + + function noteAssistantAudio(responseId) { + if (!responseId) { + return; + } + const entry = ensureAssistantMessage(responseId); + if (!entry) { + return; + } + entry.hasAudio = true; + if (!entry.hasVisibleText) { + entry.contentEl.textContent = "(streaming audio)"; + } + scrollConversationToBottom(); + } + + function finalizeAssistantMessage(responseId, text) { + if (!responseId) { + return; + } + const entry = ensureAssistantMessage(responseId); + if (!entry) { + return; + } + if (typeof text === "string" && text.length > 0) { + entry.rawText = text; + entry.contentEl.textContent = text; + entry.hasVisibleText = text.trim().length > 0; + } + if (!entry.hasVisibleText) { + entry.contentEl.textContent = entry.hasAudio + ? "(audio response)" + : "(no text output)"; + } + entry.container.classList.remove("pending"); + scrollConversationToBottom(); + } + + function updateMicLevel(level) { + const clamped = Math.max(0, Math.min(1, level)); + const percent = Math.round(clamped * 100); + micLevelFillEl.style.width = `${percent}%`; + micLevelTextEl.textContent = `${percent}%`; + } + + function updateRemoteLevel(level) { + const clamped = Math.max(0, Math.min(1, level)); + const percent = Math.round(clamped * 100); + remoteLevelFillEl.style.width = `${percent}%`; + remoteLevelTextEl.textContent = `${percent}%`; + } + + function tickRemoteLevelDecay() { + remoteLevel = Math.max(remoteLevel * 0.92, 0); + updateRemoteLevel(remoteLevel); + if (remoteLevel <= 0.01) { + remoteLevel = 0; + updateRemoteLevel(0); + remoteLevelRaf = 0; + return; + } + remoteLevelRaf = requestAnimationFrame(tickRemoteLevelDecay); + } + + function bumpRemoteLevel(level) { + remoteLevel = Math.max(remoteLevel, Math.max(0, Math.min(1, level))); + updateRemoteLevel(remoteLevel); + if (!remoteLevelRaf) { + remoteLevelRaf = requestAnimationFrame(tickRemoteLevelDecay); + } + } + + function clearPlaybackQueue() { + playbackGeneration += 1; + playbackCursor = 0; + playbackSources.forEach((source) => { + try { + source.stop(); + } catch (_) { + // Ignore already-ended nodes. + } + }); + playbackSources.clear(); + remoteLevel = 0; + updateRemoteLevel(0); + } + + function computeRms(samples) { + if (!samples || samples.length === 0) { + return 0; + } + let sum = 0; + for (let i = 0; i < samples.length; i += 1) { + const sample = samples[i]; + sum += sample * sample; + } + return Math.sqrt(sum / samples.length); + } + + function float32ToPcm16Bytes(samples) { + const pcm = new Int16Array(samples.length); + for (let i = 0; i < samples.length; i += 1) { + const sample = Math.max(-1, Math.min(1, samples[i])); + pcm[i] = sample < 0 ? Math.round(sample * 32768) : Math.round(sample * 32767); + } + return pcm.buffer; + } + + function pcm16ToFloat32(buffer) { + const int16 = new Int16Array(buffer); + const out = new Float32Array(int16.length); + for (let i = 0; i < int16.length; i += 1) { + out[i] = int16[i] / 32768; + } + return out; + } + + async function ensurePlaybackContext() { + if (!playbackAudioContext) { + const AudioContextCtor = window.AudioContext || window.webkitAudioContext; + if (!AudioContextCtor) { + throw new Error("Web Audio API is not supported"); + } + playbackAudioContext = new AudioContextCtor(); + } + if (playbackAudioContext.state === "suspended") { + await playbackAudioContext.resume(); + } + return playbackAudioContext; + } + + async function playAssistantAudioChunk(buffer) { + if (!(buffer instanceof ArrayBuffer) || buffer.byteLength === 0) { + return; + } + const audioContext = await ensurePlaybackContext(); + const samples = pcm16ToFloat32(buffer); + if (samples.length === 0) { + return; + } + + const audioBuffer = audioContext.createBuffer(1, samples.length, outputSampleRate); + audioBuffer.copyToChannel(samples, 0); + + const source = audioContext.createBufferSource(); + source.buffer = audioBuffer; + source.connect(audioContext.destination); + + const generation = playbackGeneration; + const startAt = Math.max(audioContext.currentTime + 0.04, playbackCursor || 0); + source.start(startAt); + playbackCursor = startAt + audioBuffer.duration; + playbackSources.add(source); + source.onended = () => { + playbackSources.delete(source); + if (generation !== playbackGeneration) { + return; + } + if (playbackSources.size === 0 && playbackAudioContext) { + playbackCursor = Math.max(playbackAudioContext.currentTime, 0); + } + }; + + bumpRemoteLevel(Math.min(computeRms(samples) * 4.0, 1.0)); + } + + function canSendControlEvent() { + return Boolean(ws && ws.readyState === WebSocket.OPEN); + } + + function sendControlEvent(payload) { + if (!canSendControlEvent()) { + log("control event skipped", { + type: payload && payload.type ? payload.type : "unknown", + reason: "websocket not open", + }); + return false; + } + ws.send(JSON.stringify(payload)); + return true; + } + + function updatePushToTalkUi() { + const manualMode = inputAudioMode === "manual"; + pushToTalkBtn.disabled = !(manualMode && canSendControlEvent()); + pushToTalkBtn.classList.toggle("hidden", !manualMode); + pushToTalkBtn.classList.toggle("active", pushToTalkActive); + pushToTalkBtn.textContent = pushToTalkActive + ? "Release To Commit" + : "Hold To Talk"; + } + + function updateTextPromptUi() { + const connected = canSendControlEvent(); + userPromptEl.disabled = !connected; + sendTextBtn.disabled = !(connected && userPromptEl.value.trim()); + } + + function updateAudioModeHelp() { + if (inputAudioMode === "manual") { + audioModeHelpEl.textContent = + "Push To Talk captures continuously but only commits audio while you hold the button or space bar."; + return; + } + audioModeHelpEl.textContent = + "Auto VAD streams microphone PCM continuously and lets the server detect utterance boundaries."; + } + + function setInputAudioMode(mode) { + inputAudioMode = mode === "manual" ? "manual" : "vad"; + audioModeAutoEl.checked = inputAudioMode === "vad"; + audioModePushEl.checked = inputAudioMode === "manual"; + if (inputAudioMode !== "manual") { + pushToTalkActive = false; + pushToTalkKeyActive = false; + } + updateAudioModeHelp(); + updatePushToTalkUi(); + } + + function sendSessionUpdate() { + const session = { + instructions: instructionsEl.value.trim(), + audio: { + input_mode: inputAudioMode, + }, + }; + return sendControlEvent({ + type: "session.update", + session, + }); + } + + function sendInputAudioFormat() { + return sendControlEvent({ + type: "input_audio_format", + sample_rate: captureSampleRate, + encoding: "pcm16le", + }); + } + + async function startAudioCapture(stream) { + stopAudioCapture(); + + const AudioContextCtor = window.AudioContext || window.webkitAudioContext; + if (!AudioContextCtor) { + throw new Error("Web Audio API is not supported"); + } + + captureAudioContext = new AudioContextCtor({ sampleRate: 16000 }); + if (captureAudioContext.state === "suspended") { + await captureAudioContext.resume(); + } + captureSampleRate = captureAudioContext.sampleRate; + + captureSource = captureAudioContext.createMediaStreamSource(stream); + captureProcessor = captureAudioContext.createScriptProcessor(2048, 1, 1); + captureSink = captureAudioContext.createGain(); + captureSink.gain.value = 0; + + captureProcessor.onaudioprocess = (event) => { + const input = event.inputBuffer.getChannelData(0); + const copy = new Float32Array(input.length); + copy.set(input); + updateMicLevel(Math.min(computeRms(copy) * 4.0, 1.0)); + if (!canSendControlEvent()) { + return; + } + ws.send(float32ToPcm16Bytes(copy)); + }; + + captureSource.connect(captureProcessor); + captureProcessor.connect(captureSink); + captureSink.connect(captureAudioContext.destination); + } + + function stopAudioCapture() { + if (captureProcessor) { + try { + captureProcessor.disconnect(); + } catch (_) { + // Ignore node teardown errors. + } + captureProcessor.onaudioprocess = null; + captureProcessor = null; + } + if (captureSink) { + try { + captureSink.disconnect(); + } catch (_) { + // Ignore node teardown errors. + } + captureSink = null; + } + if (captureSource) { + try { + captureSource.disconnect(); + } catch (_) { + // Ignore node teardown errors. + } + captureSource = null; + } + if (captureAudioContext) { + captureAudioContext.close().catch(() => {}); + captureAudioContext = null; + } + updateMicLevel(0); + } + + function handleServerEvent(event) { + if (!event || typeof event.type !== "string") { + return; + } + + if (event.type === "session.created") { + sessionId = event.session_id || null; + outputSampleRate = + (event.audio && Number(event.audio.output_sample_rate)) || outputSampleRate; + clearConversation(); + setStatus(sessionId ? `Connected (${sessionId})` : "Connected"); + disconnectBtn.disabled = false; + updatePushToTalkUi(); + updateTextPromptUi(); + return; + } + + if (event.type === "conversation.item.created") { + const item = event.item || {}; + if (item.role === "user" && typeof item.content === "string") { + appendConversationMessage("user", item.content); + } + return; + } + + if (event.type === "response.created") { + ensureAssistantMessage(event.response_id); + return; + } + + if (event.type === "response.output_text.delta") { + appendAssistantDelta(event.response_id, event.delta); + return; + } + + if (event.type === "response.done") { + finalizeAssistantMessage(event.response_id, event.text); + return; + } + + if (event.type === "response.output_audio.delta") { + noteAssistantAudio(event.response_id); + return; + } + + if (event.type === "response.cancelled") { + clearPlaybackQueue(); + const entry = assistantMessages.get(event.response_id); + if (entry) { + entry.container.classList.remove("pending"); + } + return; + } + + if (event.type === "session.updated") { + const session = event.session || {}; + const audio = session.audio || {}; + if (typeof audio.input_mode === "string") { + setInputAudioMode(audio.input_mode); + } + return; + } + + if (event.type === "output_audio_buffer.cleared") { + clearPlaybackQueue(); + return; + } + } + + function beginPushToTalk(source) { + if (inputAudioMode !== "manual" || pushToTalkActive) { + return; + } + if (!sendControlEvent({ type: "input_audio_buffer.start" })) { + updatePushToTalkUi(); + return; + } + pushToTalkActive = true; + updatePushToTalkUi(); + log("push-to-talk started", { source }); + } + + function commitPushToTalk(source) { + if (inputAudioMode !== "manual" || !pushToTalkActive) { + return; + } + pushToTalkActive = false; + updatePushToTalkUi(); + if (!sendControlEvent({ type: "input_audio_buffer.commit" })) { + return; + } + log("push-to-talk committed", { source }); + } + + function submitTextPrompt() { + const text = userPromptEl.value.trim(); + if (!text) { + updateTextPromptUi(); + return; + } + if ( + !sendControlEvent({ + type: "conversation.item.create", + item: { role: "user", content: text }, + }) + ) { + updateTextPromptUi(); + return; + } + if (!sendControlEvent({ type: "response.create" })) { + updateTextPromptUi(); + return; + } + userPromptEl.value = ""; + updateTextPromptUi(); + log("text prompt sent", { chars: text.length }); + } + + async function connect() { + if (ws) { + return; + } + connectBtn.disabled = true; + setStatus("Requesting microphone..."); + + localStream = await navigator.mediaDevices.getUserMedia({ + audio: { + echoCancellation: false, + noiseSuppression: false, + autoGainControl: false, + channelCount: 1, + }, + video: false, + }); + await startAudioCapture(localStream); + await ensurePlaybackContext(); + + const socket = new WebSocket(buildWebSocketUrl()); + socket.binaryType = "arraybuffer"; + ws = socket; + + socket.addEventListener("open", () => { + sendInputAudioFormat(); + sendSessionUpdate(); + setStatus("Connected"); + disconnectBtn.disabled = false; + updatePushToTalkUi(); + updateTextPromptUi(); + log("websocket opened", { + url: buildWebSocketUrl(), + capture_sample_rate: captureSampleRate, + }); + }); + + socket.addEventListener("message", (event) => { + if (typeof event.data === "string") { + try { + const payload = JSON.parse(event.data); + handleServerEvent(payload); + log("server event", payload); + } catch (err) { + log("server message", { raw: event.data }); + } + return; + } + playAssistantAudioChunk(event.data).catch((err) => { + log("audio playback error", { message: String(err) }); + }); + }); + + socket.addEventListener("error", () => { + log("websocket error", { url: buildWebSocketUrl() }); + }); + + socket.addEventListener("close", (event) => { + if (ws === socket) { + log("websocket closed", { + code: event.code, + reason: event.reason || "", + }); + disconnect(true).catch((err) => { + log("disconnect error", { message: String(err) }); + }); + } + }); + } + + async function disconnect(fromRemote = false) { + const socket = ws; + ws = null; + + if (!fromRemote && socket) { + try { + socket.close(); + } catch (_) { + // Ignore close errors. + } + } + + stopAudioCapture(); + if (localStream) { + localStream.getTracks().forEach((track) => track.stop()); + localStream = null; + } + if (playbackAudioContext) { + clearPlaybackQueue(); + try { + await playbackAudioContext.close(); + } catch (_) { + // Ignore context close errors. + } + playbackAudioContext = null; + } else { + clearPlaybackQueue(); + } + + sessionId = null; + pushToTalkActive = false; + pushToTalkKeyActive = false; + disconnectBtn.disabled = true; + connectBtn.disabled = false; + setStatus("Idle"); + updatePushToTalkUi(); + updateTextPromptUi(); + clearConversation(); + } + + connectBtn.addEventListener("click", async () => { + try { + await connect(); + } catch (err) { + log("connect error", { message: String(err) }); + await disconnect(true); + } + }); + + disconnectBtn.addEventListener("click", () => { + disconnect(false).catch((err) => { + log("disconnect error", { message: String(err) }); + }); + }); + + pushToTalkBtn.addEventListener("mousedown", (event) => { + event.preventDefault(); + beginPushToTalk("button"); + }); + pushToTalkBtn.addEventListener("mouseup", (event) => { + event.preventDefault(); + commitPushToTalk("button"); + }); + pushToTalkBtn.addEventListener("mouseleave", () => { + commitPushToTalk("button-leave"); + }); + pushToTalkBtn.addEventListener("touchstart", (event) => { + event.preventDefault(); + beginPushToTalk("touch"); + }); + pushToTalkBtn.addEventListener("touchend", (event) => { + event.preventDefault(); + commitPushToTalk("touch"); + }); + + window.addEventListener("keydown", (event) => { + if (inputAudioMode !== "manual") { + return; + } + if (event.code !== "Space" || event.repeat) { + return; + } + const tagName = event.target && event.target.tagName ? event.target.tagName : ""; + if (tagName === "TEXTAREA" || tagName === "INPUT") { + return; + } + event.preventDefault(); + pushToTalkKeyActive = true; + beginPushToTalk("space"); + }); + + window.addEventListener("keyup", (event) => { + if (inputAudioMode !== "manual") { + return; + } + if (event.code !== "Space" || !pushToTalkKeyActive) { + return; + } + event.preventDefault(); + pushToTalkKeyActive = false; + commitPushToTalk("space"); + }); + + window.addEventListener("blur", () => { + if (inputAudioMode !== "manual") { + return; + } + pushToTalkKeyActive = false; + commitPushToTalk("window-blur"); + }); + + audioModeAutoEl.addEventListener("change", () => { + if (!audioModeAutoEl.checked) { + return; + } + setInputAudioMode("vad"); + if (canSendControlEvent()) { + sendSessionUpdate(); + } + }); + + audioModePushEl.addEventListener("change", () => { + if (!audioModePushEl.checked) { + return; + } + setInputAudioMode("manual"); + if (canSendControlEvent()) { + sendSessionUpdate(); + } + }); + + clearLogBtn.addEventListener("click", () => { + logEl.textContent = ""; + }); + + sendTextBtn.addEventListener("click", () => { + submitTextPrompt(); + }); + + userPromptEl.addEventListener("input", () => { + updateTextPromptUi(); + }); + + userPromptEl.addEventListener("keydown", (event) => { + if (event.key !== "Enter" || event.shiftKey) { + return; + } + event.preventDefault(); + submitTextPrompt(); + }); + + updateMicLevel(0); + updateRemoteLevel(0); + clearConversation(); + setInputAudioMode("vad"); + updatePushToTalkUi(); + updateTextPromptUi(); +})(); diff --git a/playground/realtime-ws/app.py b/playground/realtime-ws/app.py new file mode 100644 index 000000000..b21c4c87a --- /dev/null +++ b/playground/realtime-ws/app.py @@ -0,0 +1,53 @@ +import argparse +import json +import os +from pathlib import Path + +import uvicorn +from fastapi import FastAPI +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.staticfiles import StaticFiles + +FRONTEND_DIR = Path(__file__).parent +assert FRONTEND_DIR.is_dir(), "Frontend directory does not exist" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="SGLang-Omni Realtime WS Playground") + parser.add_argument("--port", type=int, default=7862) + parser.add_argument("--api-base", type=str, default=None) + return parser.parse_args() + + +def create_app(api_base: str | None) -> FastAPI: + app = FastAPI(title="sglang-omni-realtime-ws-playground") + app.state.api_base = ( + api_base or os.environ.get("SGLANG_OMNI_API_BASE", "http://localhost:8000") + ).rstrip("/") + + @app.get("/") + async def index() -> HTMLResponse: + html = (FRONTEND_DIR / "index.html").read_text() + injection = ( + "" + ) + html = html.replace("", f"{injection}", 1) + return HTMLResponse(html) + + @app.get("/health") + async def health() -> JSONResponse: + return JSONResponse({"status": "ok", "api_base": app.state.api_base}) + + app.mount("/", StaticFiles(directory=str(FRONTEND_DIR), html=True)) + return app + + +def main() -> None: + args = parse_args() + uvicorn.run(create_app(args.api_base), host="0.0.0.0", port=args.port) + + +if __name__ == "__main__": + main() diff --git a/playground/realtime-ws/index.html b/playground/realtime-ws/index.html new file mode 100644 index 000000000..105025aed --- /dev/null +++ b/playground/realtime-ws/index.html @@ -0,0 +1,256 @@ + + + + + + SGLang-Omni Realtime WebSocket Prototype + + + +
+

Realtime WebSocket Prototype

+

+ Audio-first voice demo over a single WebSocket using raw PCM audio, + server-side VAD, and request-based inference on utterance end. +

+ +
+ + +
+
+
+ + +
+
+ Auto VAD streams microphone PCM continuously and lets the server detect utterance boundaries. +
+
+
+
+ + + + Idle +
+
+
+ Microphone level + 0% +
+ +
+
+
+ Assistant audio level + 0% +
+ +
+
+ +
+ Conversation +
+
Conversation history will appear here.
+
+ + +
+ + + Text turns share the same session history as voice turns. + +
+
+ +
+
+ Session Events + +
+
+
+
+ + + diff --git a/playground/realtime-ws/mock_server.py b/playground/realtime-ws/mock_server.py new file mode 100644 index 000000000..764e7b8d7 --- /dev/null +++ b/playground/realtime-ws/mock_server.py @@ -0,0 +1,91 @@ +import argparse + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from sglang_omni.realtime.backend import MockResponseBackend +from sglang_omni.serve.realtime_ws_api import create_realtime_ws_router + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Mock realtime websocket API for browser smoke tests" + ) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model-name", type=str, default="mock-realtime-ws") + parser.add_argument( + "--response-text", + type=str, + default="Mock websocket backend replaying the captured utterance.", + ) + parser.add_argument( + "--audio-mode", + type=str, + choices=("playback", "tone", "echo"), + default="playback", + ) + parser.add_argument("--dump-audio-dir", type=str, default=None) + parser.add_argument("--sample-rate", type=int, default=24000) + parser.add_argument("--chunk-duration", type=float, default=0.24) + parser.add_argument("--chunk-delay", type=float, default=0.08) + parser.add_argument("--total-duration", type=float, default=1.2) + parser.add_argument("--tone-frequency", type=float, default=660.0) + return parser.parse_args() + + +def create_app(args: argparse.Namespace) -> FastAPI: + app = FastAPI(title="sglang-omni-realtime-ws-mock", version="0.1.0") + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + def backend_factory(model_name: str, max_new_tokens: int): + del max_new_tokens + return MockResponseBackend( + model=model_name, + output_modalities=("text", "audio"), + response_text=args.response_text, + audio_mode=args.audio_mode, + dump_audio_dir=args.dump_audio_dir, + sample_rate=args.sample_rate, + chunk_duration_s=args.chunk_duration, + inter_chunk_delay_s=args.chunk_delay, + total_duration_s=args.total_duration, + tone_hz=args.tone_frequency, + ) + + app.include_router( + create_realtime_ws_router( + model_name=args.model_name, + backend_factory=backend_factory, + ) + ) + + @app.get("/health") + async def health() -> JSONResponse: + return JSONResponse( + { + "status": "healthy", + "running": True, + "backend": "mock-realtime-ws", + "model": args.model_name, + } + ) + + return app + + +def main() -> None: + args = parse_args() + uvicorn.run(create_app(args), host="0.0.0.0", port=args.port) + + +if __name__ == "__main__": + main() diff --git a/playground/realtime-ws/start.sh b/playground/realtime-ws/start.sh new file mode 100755 index 000000000..be8f71901 --- /dev/null +++ b/playground/realtime-ws/start.sh @@ -0,0 +1,171 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-python}" +BACKEND_ENTRY="${REPO_ROOT}/examples/run_qwen3_omni_speech_server.py" + +BACKEND_PORT="${PORT:-8000}" +PLAYGROUND_PORT="7862" +MOCK_BACKEND="0" +MOCK_ARGS=() +BACKEND_ARGS=() + +usage() { + cat <<'EOF' +Usage: + ./playground/realtime-ws/start.sh [--mock] [realtime-options] [backend-options...] + +Description: + Launch the standalone websocket realtime frontend plus either: + - the mock websocket backend with --mock, or + - the Qwen3 Omni speech server backend with forwarded backend-options. + +Realtime options: + --mock Use the mock backend instead of the Qwen3 Omni speech server. + --port PORT Backend API port. Default: 8000. + --playground-port PORT Frontend UI port. Default: 7862. + +Mock-only options: + --response-text TEXT + --audio-mode MODE + --dump-audio-dir DIR + --model-name NAME + --sample-rate HZ + --chunk-duration SECONDS + --chunk-delay SECONDS + --total-duration SECONDS + --tone-frequency HZ + +Backend options: + Any unrecognized options are forwarded to: + python examples/run_qwen3_omni_speech_server.py + +Examples: + ./playground/realtime-ws/start.sh --mock + ./playground/realtime-ws/start.sh --model-path Qwen/Qwen3-Omni-30B-A3B-Instruct +EOF +} + +pick_free_port() { + "${PYTHON_BIN}" - <<'PY' +import socket + +s = socket.socket() +s.bind(("", 0)) +print(s.getsockname()[1]) +s.close() +PY +} + +require_websocket_runtime() { + "${PYTHON_BIN}" - <<'PY' +import importlib.util +import sys + +if importlib.util.find_spec("websockets") or importlib.util.find_spec("wsproto"): + raise SystemExit(0) +print( + "ERROR: WebSocket runtime support is missing. Install project dependencies with\n" + " uv pip install -e .\n" + "or install one of:\n" + " pip install websockets\n" + " pip install wsproto" +) +raise SystemExit(1) +PY +} + +while [[ $# -gt 0 ]]; do + case "$1" in + -h|--help) usage; exit 0 ;; + --mock) MOCK_BACKEND="1"; shift ;; + --port) BACKEND_PORT="$2"; shift 2 ;; + --playground-port) PLAYGROUND_PORT="$2"; shift 2 ;; + --response-text|--audio-mode|--dump-audio-dir|--model-name|--sample-rate|--chunk-duration|--chunk-delay|--total-duration|--tone-frequency) + MOCK_ARGS+=("$1" "$2"); shift 2 ;; + --pipeline) shift 2 ;; + *) BACKEND_ARGS+=("$1"); shift ;; + esac +done + +if [[ "${MOCK_BACKEND}" != "1" && ${#BACKEND_ARGS[@]} -eq 0 ]]; then + usage + exit 1 +fi + +if ! "${PYTHON_BIN}" -c "import socket; s=socket.socket(); s.bind(('0.0.0.0',${BACKEND_PORT})); s.close()" 2>/dev/null; then + echo "WARNING: Port ${BACKEND_PORT} is already in use." + BACKEND_PORT=$(pick_free_port) + echo "Using port ${BACKEND_PORT} instead." +fi + +API_BASE="http://localhost:${BACKEND_PORT}" +if ! require_websocket_runtime; then + exit 1 +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]]; then + kill "${SERVER_PID}" 2>/dev/null || true + wait "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT INT TERM + +echo "============================================================" +echo " SGLang-Omni Realtime WebSocket Playground" +echo "============================================================" +echo "" +echo " Backend API: ${API_BASE}" +echo " Frontend UI: http://localhost:${PLAYGROUND_PORT}" +echo "" +echo "============================================================" +echo "" + +if [[ "${MOCK_BACKEND}" == "1" ]]; then + echo "[1/2] Starting mock websocket realtime API server..." + "${PYTHON_BIN}" "${SCRIPT_DIR}/mock_server.py" \ + --port "${BACKEND_PORT}" \ + "${MOCK_ARGS[@]}" & +else + echo "[1/2] Starting backend server with arguments: ${BACKEND_ARGS[*]}" + "${PYTHON_BIN}" "${BACKEND_ENTRY}" \ + "${BACKEND_ARGS[@]}" \ + --port "${BACKEND_PORT}" & +fi +SERVER_PID=$! + +echo "[2/2] Waiting for server to be ready..." +for i in $(seq 1 120); do + if ! kill -0 "${SERVER_PID}" 2>/dev/null; then + echo "ERROR: Backend server exited unexpectedly." + exit 1 + fi + if curl -s "${API_BASE}/health" 2>/dev/null | grep -q "healthy"; then + echo "Server is ready." + break + fi + if [[ $i -eq 120 ]]; then + echo "ERROR: Server did not become healthy within 600s." + exit 1 + fi + sleep 5 +done + +echo "" +echo "============================================================" +echo " Server is ready!" +echo "============================================================" +echo "" +echo " Frontend UI: http://localhost:${PLAYGROUND_PORT}" +echo " Backend API: ${API_BASE}" +echo "" +echo "============================================================" +echo "" + +export SGLANG_OMNI_API_BASE="${API_BASE}" +"${PYTHON_BIN}" "${SCRIPT_DIR}/app.py" \ + --api-base "${API_BASE}" \ + --port "${PLAYGROUND_PORT}" diff --git a/pyproject.toml b/pyproject.toml index 0e02f3f5b..82b29610b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,9 @@ dependencies = [ "torchaudio", # Gradio playground "gradio>=4.0.0", + # Realtime websocket playground + "webrtcvad-wheels>=2.0.14", + "websockets>=12.0", # flash-attn: install separately via prebuilt wheel (see install instructions below) ] diff --git a/sglang_omni/client/client.py b/sglang_omni/client/client.py index 92e149014..5cd4a6f84 100644 --- a/sglang_omni/client/client.py +++ b/sglang_omni/client/client.py @@ -411,21 +411,27 @@ def _extract_inputs(request: GenerateRequest) -> Any: # Build messages list messages = [msg.to_dict() for msg in request.messages or []] - # Check if we have audios, images, or videos in metadata - audios = request.metadata.get("audios") - images = request.metadata.get("images") - videos = request.metadata.get("videos") - - # If we have any media, return a dict with messages and media - # Otherwise, return just the messages list (for backward compatibility) - if audios or images or videos: + media_input_keys = ( + "audios", + "images", + "videos", + "audio_target_sr", + "video_fps", + "use_audio_in_video", + "video_seconds_per_chunk", + "video_position_id_per_seconds", + ) + media_inputs = { + key: request.metadata.get(key) + for key in media_input_keys + if request.metadata.get(key) is not None + } + + # If we have any media or media preprocessing hints, return a dict with + # messages plus those fields. Otherwise keep the legacy bare-messages shape. + if media_inputs: result = {"messages": messages} - if images: - result["images"] = images - if audios: - result["audios"] = audios - if videos: - result["videos"] = videos + result.update(media_inputs) return result return messages diff --git a/sglang_omni/pipeline/coordinator.py b/sglang_omni/pipeline/coordinator.py index 689b25948..0a2318a62 100644 --- a/sglang_omni/pipeline/coordinator.py +++ b/sglang_omni/pipeline/coordinator.py @@ -2,6 +2,7 @@ """Coordinator for managing the multi-stage pipeline.""" import asyncio +import contextlib import logging from typing import Any, AsyncIterator @@ -144,8 +145,18 @@ async def stream( else: yield msg finally: + info = self._requests.get(request_id) + if info is not None and info.state not in { + RequestState.COMPLETED, + RequestState.FAILED, + RequestState.ABORTED, + }: + await self.abort(request_id) self._stream_queues.pop(request_id, None) - self._completion_futures.pop(request_id, None) + future = self._completion_futures.pop(request_id, None) + if future is not None and future.done(): + with contextlib.suppress(asyncio.CancelledError, Exception): + future.exception() async def _submit_request( self, request_id: str, request: OmniRequest | Any diff --git a/sglang_omni/preprocessing/audio.py b/sglang_omni/preprocessing/audio.py index 59f90f4c7..14cefb862 100644 --- a/sglang_omni/preprocessing/audio.py +++ b/sglang_omni/preprocessing/audio.py @@ -197,6 +197,33 @@ def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: return resampled, float(self.target_sr) +def _decode_audio_waveform_payload( + item: dict[str, Any], + *, + target_sr: int, +) -> np.ndarray: + raw = item.get("audio_waveform") + if raw is None: + raise ValueError("Missing audio_waveform payload") + if isinstance(raw, memoryview): + raw = raw.tobytes() + if not isinstance(raw, (bytes, bytearray)): + raise TypeError("audio_waveform payload must be bytes-like") + + dtype = np.dtype(item.get("audio_waveform_dtype", "float32")) + audio = np.frombuffer(raw, dtype=dtype) + shape = item.get("audio_waveform_shape") + if shape: + audio = audio.reshape(shape) + audio = audio.astype(np.float32, copy=False) + + source_sr = item.get("sample_rate") + if isinstance(source_sr, int) and source_sr > 0 and source_sr != target_sr: + audio = _resample_linear(audio, source_sr, target_sr) + + return audio.copy() + + async def ensure_audio_list_async( audios: Any, *, @@ -244,6 +271,8 @@ async def ensure_audio_list_async( else: # Local path - can be loaded synchronously normalized.append(load_audio_path(item, target_sr=target_sr)) + elif isinstance(item, dict) and item.get("audio_waveform") is not None: + normalized.append(_decode_audio_waveform_payload(item, target_sr=target_sr)) else: # Already processed (numpy array, etc.) normalized.append(item) diff --git a/sglang_omni/realtime/__init__.py b/sglang_omni/realtime/__init__.py new file mode 100644 index 000000000..b520d2218 --- /dev/null +++ b/sglang_omni/realtime/__init__.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Realtime helpers shared by the interactive transports.""" + +from sglang_omni.realtime.backend import ( + BackendCapabilities, + MockResponseBackend, + OmniResponseBackend, + ResponseBackend, + ResponseEvent, + TurnContext, +) +from sglang_omni.realtime.session import RealtimeSession, RealtimeSessionConfig +from sglang_omni.realtime.vad import EnergyVad, VadConfig + +__all__ = [ + "BackendCapabilities", + "EnergyVad", + "MockResponseBackend", + "OmniResponseBackend", + "RealtimeSession", + "RealtimeSessionConfig", + "ResponseBackend", + "ResponseEvent", + "TurnContext", + "VadConfig", +] diff --git a/sglang_omni/realtime/backend/__init__.py b/sglang_omni/realtime/backend/__init__.py new file mode 100644 index 000000000..6b0288e60 --- /dev/null +++ b/sglang_omni/realtime/backend/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Realtime response backends.""" + +from sglang_omni.realtime.backend.base import ( + BackendCapabilities, + ResponseBackend, + ResponseEvent, + TurnContext, +) +from sglang_omni.realtime.backend.mock import MockResponseBackend +from sglang_omni.realtime.backend.omni import OmniResponseBackend + +__all__ = [ + "BackendCapabilities", + "MockResponseBackend", + "OmniResponseBackend", + "ResponseBackend", + "ResponseEvent", + "TurnContext", +] diff --git a/sglang_omni/realtime/backend/base.py b/sglang_omni/realtime/backend/base.py new file mode 100644 index 000000000..09f59427d --- /dev/null +++ b/sglang_omni/realtime/backend/base.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Backend abstraction for realtime turn responses.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import Any, Protocol + + +@dataclass(frozen=True) +class BackendCapabilities: + accepts_audio_input: bool = False + accepts_video_input: bool = False + returns_text: bool = True + returns_audio: bool = False + supports_cancel: bool = True + + +@dataclass +class TurnContext: + session_id: str + history: list[dict[str, str]] + instructions: str | None + user_text: str | None + user_audio: Any | None + user_audio_sample_rate: int | None + recent_video: Any | None + recent_video_fps: float | None + turn_index: int | None = None + + +@dataclass +class ResponseEvent: + type: str + response_id: str + text: str | None = None + audio: Any | None = None + sample_rate: int | None = None + finish_reason: str | None = None + error: str | None = None + + +class ResponseBackend(Protocol): + @property + def model_name(self) -> str: ... + + @property + def capabilities(self) -> BackendCapabilities: ... + + async def stream_response( + self, + turn: TurnContext, + ) -> AsyncIterator[ResponseEvent]: ... + + async def cancel(self, response_id: str) -> None: ... diff --git a/sglang_omni/realtime/backend/mock.py b/sglang_omni/realtime/backend/mock.py new file mode 100644 index 000000000..13761b786 --- /dev/null +++ b/sglang_omni/realtime/backend/mock.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Mock response backend for browser smoke tests.""" + +from __future__ import annotations + +import asyncio +import uuid +from collections.abc import AsyncIterator +from pathlib import Path + +import numpy as np +import soundfile as sf + +from sglang_omni.realtime.backend.base import ( + BackendCapabilities, + ResponseBackend, + ResponseEvent, + TurnContext, +) + + +class MockResponseBackend(ResponseBackend): + """Stream mock text plus playback, conditioned echo, or a test tone.""" + + def __init__( + self, + *, + model: str = "mock-realtime", + output_modalities: tuple[str, ...] = ("text", "audio"), + response_text: str = "Mock backend replaying the captured utterance.", + audio_mode: str = "playback", + dump_audio_dir: str | None = None, + sample_rate: int = 24000, + chunk_duration_s: float = 0.24, + inter_chunk_delay_s: float = 0.08, + total_duration_s: float = 1.2, + tone_hz: float = 660.0, + ) -> None: + self._model = model + self._output_modalities = output_modalities + self._response_text = response_text.strip() or "Mock backend response." + self._audio_mode = audio_mode + self._dump_audio_dir = ( + Path(dump_audio_dir).expanduser() if dump_audio_dir else None + ) + self._sample_rate = sample_rate + self._chunk_duration_s = chunk_duration_s + self._inter_chunk_delay_s = inter_chunk_delay_s + self._total_duration_s = total_duration_s + self._tone_hz = tone_hz + self._cancel_events: dict[str, asyncio.Event] = {} + if self._audio_mode not in {"playback", "tone", "echo"}: + raise ValueError( + "Unsupported mock audio mode " + f"{self._audio_mode!r}; expected 'playback', 'echo', or 'tone'." + ) + if self._dump_audio_dir is not None: + self._dump_audio_dir.mkdir(parents=True, exist_ok=True) + self._capabilities = BackendCapabilities( + accepts_audio_input=True, + accepts_video_input=True, + returns_text="text" in output_modalities, + returns_audio="audio" in output_modalities, + supports_cancel=True, + ) + + @property + def model_name(self) -> str: + return self._model + + @property + def capabilities(self) -> BackendCapabilities: + return self._capabilities + + async def stream_response( + self, + turn: TurnContext, + ) -> AsyncIterator[ResponseEvent]: + response_id = uuid.uuid4().hex + cancel_event = asyncio.Event() + self._cancel_events[response_id] = cancel_event + try: + yield ResponseEvent(type="response_started", response_id=response_id) + self._dump_captured_audio(turn, response_id=response_id) + + if self._capabilities.returns_text: + for text_delta in self._split_text(self._response_text): + if cancel_event.is_set(): + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason="cancelled", + ) + return + yield ResponseEvent( + type="text_delta", + response_id=response_id, + text=text_delta, + ) + + if self._capabilities.returns_audio: + audio_chunks, sample_rate = self._build_audio_chunks(turn) + for chunk in audio_chunks: + if cancel_event.is_set(): + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason="cancelled", + ) + return + yield ResponseEvent( + type="audio_chunk", + response_id=response_id, + audio=chunk, + sample_rate=sample_rate, + ) + if self._inter_chunk_delay_s > 0: + await asyncio.sleep(self._inter_chunk_delay_s) + + finish_reason = "cancelled" if cancel_event.is_set() else "stop" + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason=finish_reason, + ) + finally: + self._cancel_events.pop(response_id, None) + + async def cancel(self, response_id: str) -> None: + event = self._cancel_events.get(response_id) + if event is not None: + event.set() + + def _split_text(self, text: str) -> list[str]: + parts = [segment.strip() for segment in text.split(".") if segment.strip()] + if not parts: + return [text] + return [f"{part}. " for part in parts[:-1]] + [f"{parts[-1]}."] + + def _build_audio_chunks( + self, + turn: TurnContext, + ) -> tuple[list[np.ndarray], int]: + waveform, sample_rate = self._resolve_response_audio(turn) + total_samples = int(waveform.size) + chunk_samples = max(1, int(round(sample_rate * self._chunk_duration_s))) + return [ + waveform[start : start + chunk_samples] + for start in range(0, total_samples, chunk_samples) + ], sample_rate + + def _resolve_response_audio(self, turn: TurnContext) -> tuple[np.ndarray, int]: + if self._audio_mode == "playback" and turn.user_audio is not None: + waveform = np.asarray(turn.user_audio, dtype=np.float32).reshape(-1) + if waveform.size > 0: + sample_rate = int(turn.user_audio_sample_rate or self._sample_rate) + return waveform.astype(np.float32, copy=False), sample_rate + + if self._audio_mode == "echo" and turn.user_audio is not None: + waveform = np.asarray(turn.user_audio, dtype=np.float32).reshape(-1) + if waveform.size > 0: + sample_rate = int(turn.user_audio_sample_rate or self._sample_rate) + return self._condition_echo_waveform(waveform), sample_rate + + sample_rate = self._sample_rate + total_samples = max(1, int(round(sample_rate * self._total_duration_s))) + waveform = self._build_demo_waveform(total_samples) + return waveform, sample_rate + + def _condition_echo_waveform(self, waveform: np.ndarray) -> np.ndarray: + waveform = np.asarray(waveform, dtype=np.float32).reshape(-1) + if waveform.size == 0: + return waveform + + # Browser / device capture can carry a DC bias and occasional hot peaks. + # Remove the bias and cap the peak so the mock echo stays intelligible. + waveform = waveform - float(np.mean(waveform)) + peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0 + target_peak = 0.35 + if peak > target_peak and peak > 0.0: + waveform = waveform * (target_peak / peak) + return np.clip(waveform, -0.95, 0.95).astype(np.float32, copy=False) + + def _dump_captured_audio(self, turn: TurnContext, *, response_id: str) -> None: + if self._dump_audio_dir is None or turn.user_audio is None: + return + + waveform = np.asarray(turn.user_audio, dtype=np.float32).reshape(-1) + if waveform.size == 0: + return + + sample_rate = int(turn.user_audio_sample_rate or self._sample_rate) + if turn.turn_index is not None: + turn_dir = ( + self._dump_audio_dir + / turn.session_id + / f"turn_{int(turn.turn_index):04d}" + ) + turn_dir.mkdir(parents=True, exist_ok=True) + path = ( + turn_dir / f"04_backend_turn_context_{response_id}_sr{sample_rate}.wav" + ) + else: + path = ( + self._dump_audio_dir / f"{turn.session_id}_{response_id}_captured.wav" + ) + sf.write(path, waveform, sample_rate) + + def _build_demo_waveform(self, num_samples: int) -> np.ndarray: + t = np.arange(num_samples, dtype=np.float32) / float(self._sample_rate) + carrier = np.sin(2.0 * np.pi * self._tone_hz * t) + modulator = 0.55 + 0.45 * np.sin(2.0 * np.pi * 2.0 * t) + waveform = 0.18 * carrier * modulator + + fade_samples = min(num_samples // 8, max(1, self._sample_rate // 50)) + if fade_samples > 0: + fade_in = np.linspace(0.0, 1.0, fade_samples, dtype=np.float32) + fade_out = fade_in[::-1] + waveform[:fade_samples] *= fade_in + waveform[-fade_samples:] *= fade_out + return waveform.astype(np.float32, copy=False) diff --git a/sglang_omni/realtime/backend/omni.py b/sglang_omni/realtime/backend/omni.py new file mode 100644 index 000000000..4e845f50e --- /dev/null +++ b/sglang_omni/realtime/backend/omni.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Response backend for omni chat models that emit text and audio directly.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncIterator +from typing import Any + +import numpy as np + +from sglang_omni.client import Client, GenerateRequest, Message, SamplingParams +from sglang_omni.client.audio import DEFAULT_SAMPLE_RATE +from sglang_omni.realtime.backend.base import ( + BackendCapabilities, + ResponseBackend, + ResponseEvent, + TurnContext, +) + + +class OmniResponseBackend(ResponseBackend): + """Realtime backend backed by the existing request-oriented omni client.""" + + def __init__( + self, + *, + client: Client, + model: str, + max_new_tokens: int = 256, + output_modalities: tuple[str, ...] = ("text", "audio"), + ) -> None: + self._client = client + self._model = model + self._max_new_tokens = max_new_tokens + self._output_modalities = output_modalities + self._capabilities = BackendCapabilities( + accepts_audio_input=True, + accepts_video_input=True, + returns_text="text" in output_modalities, + returns_audio="audio" in output_modalities, + supports_cancel=True, + ) + + @property + def model_name(self) -> str: + return self._model + + @property + def capabilities(self) -> BackendCapabilities: + return self._capabilities + + async def stream_response( + self, + turn: TurnContext, + ) -> AsyncIterator[ResponseEvent]: + response_id = uuid.uuid4().hex + yield ResponseEvent(type="response_started", response_id=response_id) + + finish_reason = "stop" + emitted_text = "" + request = self._build_request(turn) + try: + async for chunk in self._client.generate(request, request_id=response_id): + if chunk.finish_reason is not None: + finish_reason = chunk.finish_reason + continue + + if chunk.text: + text_delta = self._coerce_text_delta(chunk.text, emitted_text) + if text_delta: + emitted_text += text_delta + yield ResponseEvent( + type="text_delta", + response_id=response_id, + text=text_delta, + ) + + if chunk.audio_data is not None: + yield ResponseEvent( + type="audio_chunk", + response_id=response_id, + audio=np.asarray(chunk.audio_data), + sample_rate=chunk.sample_rate or DEFAULT_SAMPLE_RATE, + ) + except Exception as exc: + yield ResponseEvent( + type="error", + response_id=response_id, + error=str(exc), + ) + return + + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason=finish_reason, + ) + + async def cancel(self, response_id: str) -> None: + await self._client.abort(response_id) + + @staticmethod + def _serialize_audio_input(audio: Any) -> dict[str, Any]: + audio_np = np.asarray(audio, dtype=np.float32) + return { + "audio_waveform": audio_np.tobytes(), + "audio_waveform_shape": list(audio_np.shape), + "audio_waveform_dtype": "float32", + } + + @staticmethod + def _coerce_text_delta(text: str, emitted_text: str) -> str: + if not text: + return "" + if not emitted_text: + return text + if text == emitted_text: + return "" + if text.startswith(emitted_text): + return text[len(emitted_text) :] + return text + + def _build_request(self, turn: TurnContext) -> GenerateRequest: + messages: list[Message] = [] + if turn.instructions: + messages.append(Message(role="system", content=turn.instructions)) + messages.extend( + Message(role=item["role"], content=item["content"]) for item in turn.history + ) + messages.append(Message(role="user", content=turn.user_text or " ")) + + metadata: dict[str, Any] = {} + if turn.user_audio is not None: + metadata["audios"] = [self._serialize_audio_input(turn.user_audio)] + if turn.user_audio_sample_rate is not None: + metadata["audio_target_sr"] = int(turn.user_audio_sample_rate) + if turn.recent_video is not None: + metadata["videos"] = [turn.recent_video] + metadata["video_fps"] = turn.recent_video_fps + + return GenerateRequest( + model=self._model, + messages=messages, + sampling=SamplingParams(max_new_tokens=self._max_new_tokens), + stream=True, + output_modalities=list(self._output_modalities), + metadata=metadata, + ) diff --git a/sglang_omni/realtime/media.py b/sglang_omni/realtime/media.py new file mode 100644 index 000000000..e02687825 --- /dev/null +++ b/sglang_omni/realtime/media.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Realtime media helpers shared across realtime transports.""" + +from __future__ import annotations + +import numpy as np +from PIL import Image + + +def mono_float32(audio: Any) -> np.ndarray: + """Normalize arbitrary mono/stereo audio into mono float32 in [-1, 1].""" + arr = np.asarray(audio) + if arr.ndim == 0: + arr = arr.reshape(1) + + if np.issubdtype(arr.dtype, np.integer): + scale = max(abs(np.iinfo(arr.dtype).min), np.iinfo(arr.dtype).max) + arr = arr.astype(np.float32) / float(scale) + else: + arr = arr.astype(np.float32, copy=False) + + if arr.ndim > 1: + if arr.shape[0] <= arr.shape[-1]: + arr = arr.mean(axis=0) + else: + arr = arr.mean(axis=1) + + return np.clip(arr, -1.0, 1.0) + + +def resample_linear( + audio: np.ndarray, + orig_sr: int, + target_sr: int, +) -> np.ndarray: + """Resample 1D audio with linear interpolation.""" + if orig_sr == target_sr: + return audio.astype(np.float32, copy=False) + if audio.size == 0: + return audio.astype(np.float32, copy=False) + + duration = audio.shape[0] / float(orig_sr) + new_len = max(int(round(duration * target_sr)), 1) + old_idx = np.arange(audio.shape[0], dtype=np.float64) + new_idx = np.linspace(0.0, audio.shape[0] - 1, num=new_len, dtype=np.float64) + return np.interp(new_idx, old_idx, audio).astype(np.float32) + + +def resize_rgb_frame( + frame_rgb: np.ndarray, + *, + width: int, + height: int, +) -> np.ndarray: + """Resize an HWC RGB frame to a bounded size on CPU.""" + image = Image.fromarray(frame_rgb.astype(np.uint8), mode="RGB") + resized = image.resize((width, height), Image.Resampling.BILINEAR) + return np.asarray(resized, dtype=np.uint8) diff --git a/sglang_omni/realtime/session.py b/sglang_omni/realtime/session.py new file mode 100644 index 000000000..dd6811c72 --- /dev/null +++ b/sglang_omni/realtime/session.py @@ -0,0 +1,645 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Realtime session orchestration shared across interactive transports.""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import torch + +from sglang_omni.realtime.backend import ResponseBackend, TurnContext +from sglang_omni.realtime.media import mono_float32, resample_linear, resize_rgb_frame +from sglang_omni.realtime.utils import throttle +from sglang_omni.realtime.vad import EnergyVad, VadConfig + + +@dataclass +class VideoFrameSample: + ts_monotonic: float + frame_rgb: np.ndarray + + +@dataclass +class VideoBufferConfig: + ingest_fps: float = 2.0 + clip_window_s: float = 4.0 + max_buffer_s: float = 8.0 + max_frames: int = 16 + resize_width: int = 224 + resize_height: int = 224 + + +@dataclass +class VideoBufferState: + config: VideoBufferConfig = field(default_factory=VideoBufferConfig) + frames: deque[VideoFrameSample] = field(default_factory=deque) + last_ingest_ts: float | None = None + total_frames_received: int = 0 + + +@dataclass +class AudioTurnState: + sample_rate: int + chunks: list[np.ndarray] = field(default_factory=list) + speech_start_ts: float | None = None + speech_end_ts: float | None = None + + +@dataclass +class PendingTurn: + audio: np.ndarray | None + sample_rate: int | None + user_text: str | None + speech_end_ts: float | None + turn_index: int + + +@dataclass +class RealtimeSessionConfig: + instructions: str = ( + "You are a concise, natural voice assistant. Answer conversationally." + ) + input_audio_sample_rate: int = 16000 + input_audio_mode: str = "vad" + vad: VadConfig = field(default_factory=VadConfig) + video: VideoBufferConfig = field(default_factory=VideoBufferConfig) + + +class RealtimeSession: + """Conversation/session state above the request-oriented pipeline.""" + + def __init__( + self, + *, + session_id: str, + backend: ResponseBackend, + output_track: Any, + config: RealtimeSessionConfig, + ) -> None: + self.session_id = session_id + self.backend = backend + self.output_track = output_track + self.config = config + + self.instructions = config.instructions.strip() + self.history: list[dict[str, str]] = [] + self.current_user_text: str | None = None + self.current_audio = AudioTurnState(sample_rate=config.input_audio_sample_rate) + self.video = VideoBufferState(config=config.video) + self._audio_chunk_count = 0 + self.turn_mode = "vad" + self.manual_recording = False + + self.vad = EnergyVad(config.vad) + self._preroll_chunks: deque[np.ndarray] = deque() + self._preroll_samples = 0 + self._event_channel: Any | None = None + self._event_backlog: list[dict[str, Any]] = [] + self._response_lock = asyncio.Lock() + self._closed = False + self._throttle_state: dict[str, float] = {} + self._turn_index = 0 + self._queued_pending_turn: PendingTurn | None = None + + self.active_response_id: str | None = None + self.active_task: asyncio.Task[None] | None = None + self.assistant_playing = False + self._set_turn_mode(config.input_audio_mode) + + async def emit_event(self, event_type: str, **payload: Any) -> None: + event = {"type": event_type, "session_id": self.session_id, **payload} + channel = self._event_channel + if channel is None or getattr(channel, "readyState", None) != "open": + self._event_backlog.append(event) + return + try: + channel.send(json.dumps(event)) + except Exception: + self._event_backlog.append(event) + + def attach_event_channel(self, channel: Any) -> None: + self._event_channel = channel + self._flush_event_backlog() + + def _flush_event_backlog(self) -> None: + channel = self._event_channel + if channel is None or getattr(channel, "readyState", None) != "open": + return + pending = list(self._event_backlog) + self._event_backlog.clear() + for event in pending: + try: + channel.send(json.dumps(event)) + except Exception: + self._event_backlog.append(event) + break + + async def handle_client_event(self, payload: dict[str, Any]) -> None: + event_type = str(payload.get("type") or "").strip() + if event_type == "session.update": + session = payload.get("session") or {} + updated_session: dict[str, Any] = {} + instructions = session.get("instructions") or payload.get("instructions") + if isinstance(instructions, str): + self.instructions = instructions.strip() + updated_session["instructions"] = self.instructions + audio_config = session.get("audio") or {} + input_mode = ( + audio_config.get("input_mode") + or audio_config.get("turn_mode") + or payload.get("input_audio_mode") + ) + if isinstance(input_mode, str) and self._set_turn_mode(input_mode): + updated_session["audio"] = {"input_mode": self.turn_mode} + if updated_session: + await self.emit_event("session.updated", session=updated_session) + return + + if event_type == "input_audio_buffer.start": + if self._closed: + return + if self.active_task is not None or self.assistant_playing: + await self._interrupt_active_response(reason="barge_in") + self.turn_mode = "manual" + self.manual_recording = True + self.current_audio = AudioTurnState( + sample_rate=self.config.input_audio_sample_rate, + speech_start_ts=time.monotonic(), + ) + self._preroll_chunks.clear() + self._preroll_samples = 0 + await self.emit_event("input_audio_buffer.manual_started") + return + + if event_type == "input_audio_buffer.commit": + if self.turn_mode != "manual": + return + self.manual_recording = False + if not self.current_audio.chunks: + await self.emit_event("input_audio_buffer.manual_committed", empty=True) + self.current_audio = AudioTurnState( + sample_rate=self.config.input_audio_sample_rate + ) + return + + self.current_audio.speech_end_ts = time.monotonic() + await self.emit_event( + "input_audio_buffer.manual_committed", + empty=False, + sample_count=int( + sum(chunk.size for chunk in self.current_audio.chunks) + ), + ) + pending = self._consume_pending_turn() + if pending is not None: + await self._start_or_queue_response(pending) + return + + if event_type == "conversation.item.create": + item = payload.get("item") or {} + if item.get("role") == "user" and isinstance(item.get("content"), str): + self.current_user_text = item["content"] + await self.emit_event("conversation.item.created", item=item) + return + + if event_type == "response.create": + if self._closed: + return + if self.active_task is not None or self.assistant_playing: + await self._interrupt_active_response(reason="barge_in") + pending = self._consume_pending_turn() + if pending is not None: + await self._start_or_queue_response(pending) + return + + if event_type == "response.cancel" and ( + self.active_response_id is not None or self.active_task is not None + ): + await self._interrupt_active_response(reason="client") + + async def close(self) -> None: + self._closed = True + self._queued_pending_turn = None + if ( + self.active_response_id is not None + and self.backend.capabilities.supports_cancel + ): + await self.backend.cancel(self.active_response_id) + if self.active_task is not None: + self.active_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self.active_task + await self.output_track.clear() + + async def handle_audio_chunk( + self, + audio: np.ndarray, + sample_rate: int, + *, + timestamp: float | None = None, + ) -> None: + if self._closed: + return + + ts = timestamp if timestamp is not None else time.monotonic() + chunk = mono_float32(audio) + if sample_rate != self.config.input_audio_sample_rate: + chunk = resample_linear( + chunk, + sample_rate, + self.config.input_audio_sample_rate, + ) + if chunk.size == 0: + return + + self._audio_chunk_count += 1 + if self.turn_mode == "manual": + if not self.manual_recording: + return + self.current_audio.chunks.append(chunk) + rms = self.vad.measure_level(chunk) + dc_offset = float(np.mean(chunk)) if chunk.size else 0.0 + await self._emit_audio_chunk_received( + timestamp=ts, + chunk_count=self._audio_chunk_count, + sample_count=int(chunk.size), + sample_rate=int(self.config.input_audio_sample_rate), + rms=rms, + dc_offset=dc_offset, + frame_count=0, + voiced_frame_count=0, + speech_ratio=0.0, + speaking_before=True, + speaking_after=True, + ) + return + + was_speaking = self.vad.speaking + if not was_speaking: + self._append_preroll(chunk) + + rms = self.vad.measure_level(chunk) + dc_offset = float(np.mean(chunk)) if chunk.size else 0.0 + event = self.vad.process(chunk) + await self._emit_audio_chunk_received( + timestamp=ts, + chunk_count=self._audio_chunk_count, + sample_count=int(chunk.size), + sample_rate=int(self.config.input_audio_sample_rate), + rms=rms, + dc_offset=dc_offset, + frame_count=int(getattr(self.vad, "last_frame_count", 0)), + voiced_frame_count=int(getattr(self.vad, "last_voiced_frame_count", 0)), + speech_ratio=float(getattr(self.vad, "last_speech_ratio", 0.0)), + speaking_before=bool(was_speaking), + speaking_after=bool(self.vad.speaking), + ) + if event.speech_started: + if self.active_task is not None or self.assistant_playing: + await self._interrupt_active_response(reason="barge_in") + self.current_audio = AudioTurnState( + sample_rate=self.config.input_audio_sample_rate, + chunks=list(self._preroll_chunks), + speech_start_ts=ts, + ) + self._preroll_chunks.clear() + self._preroll_samples = 0 + await self.emit_event("input_audio_buffer.speech_started") + elif was_speaking: + self.current_audio.chunks.append(chunk) + + if event.speech_stopped and self.current_audio.chunks: + self.current_audio.speech_end_ts = ts + await self.emit_event("input_audio_buffer.speech_stopped") + pending = self._consume_pending_turn() + if pending is not None: + await self._start_or_queue_response(pending) + + async def handle_video_frame( + self, + frame_rgb: np.ndarray, + *, + timestamp: float | None = None, + ) -> None: + if self._closed: + return + ts = timestamp if timestamp is not None else time.monotonic() + cfg = self.video.config + if ( + self.video.last_ingest_ts is not None + and cfg.ingest_fps > 0 + and (ts - self.video.last_ingest_ts) < (1.0 / cfg.ingest_fps) + ): + return + + resized = resize_rgb_frame( + frame_rgb, + width=cfg.resize_width, + height=cfg.resize_height, + ) + self.video.frames.append(VideoFrameSample(ts_monotonic=ts, frame_rgb=resized)) + self.video.last_ingest_ts = ts + self.video.total_frames_received += 1 + + min_allowed_ts = ts - cfg.max_buffer_s + while self.video.frames and self.video.frames[0].ts_monotonic < min_allowed_ts: + self.video.frames.popleft() + while len(self.video.frames) > cfg.max_frames: + self.video.frames.popleft() + + await self._emit_video_frame_received( + timestamp=ts, + frame_count=self.video.total_frames_received, + buffered_frames=len(self.video.frames), + width=int(resized.shape[1]), + height=int(resized.shape[0]), + ) + + def _append_preroll(self, chunk: np.ndarray) -> None: + max_samples = int( + round(self.config.vad.preroll_s * self.config.input_audio_sample_rate) + ) + self._preroll_chunks.append(chunk) + self._preroll_samples += chunk.shape[0] + while self._preroll_chunks and self._preroll_samples > max_samples: + removed = self._preroll_chunks.popleft() + self._preroll_samples -= removed.shape[0] + + def _consume_pending_turn(self) -> PendingTurn | None: + user_text = self.current_user_text.strip() if self.current_user_text else None + if not self.current_audio.chunks and not user_text: + return None + audio = None + sample_rate: int | None = None + if self.current_audio.chunks: + audio = np.concatenate(self.current_audio.chunks, axis=0) + sample_rate = self.current_audio.sample_rate + self._turn_index += 1 + pending = PendingTurn( + audio=audio, + sample_rate=sample_rate, + user_text=user_text, + speech_end_ts=self.current_audio.speech_end_ts, + turn_index=self._turn_index, + ) + self.current_audio = AudioTurnState( + sample_rate=self.config.input_audio_sample_rate + ) + self.current_user_text = None + return pending + + def _set_turn_mode(self, mode: str) -> bool: + normalized = str(mode).strip().lower() + if normalized not in {"vad", "manual"}: + return False + self.turn_mode = normalized + self.manual_recording = False + self.current_audio = AudioTurnState( + sample_rate=self.config.input_audio_sample_rate + ) + self._preroll_chunks.clear() + self._preroll_samples = 0 + reset_vad = getattr(self.vad, "reset", None) + if callable(reset_vad): + reset_vad() + else: + self.vad.speaking = False + return True + + async def _start_or_queue_response(self, pending: PendingTurn) -> None: + if self.active_task is None: + self.assistant_playing = True + self.active_task = asyncio.create_task(self._run_response(pending)) + return + self._queued_pending_turn = pending + + async def _interrupt_active_response(self, *, reason: str) -> None: + had_active_response = ( + self.active_task is not None + or self.active_response_id is not None + or self.assistant_playing + ) + if not had_active_response: + return + + response_id = self.active_response_id + task = self.active_task + if response_id is not None and self.backend.capabilities.supports_cancel: + await self.backend.cancel(response_id) + elif task is not None: + task.cancel() + + await self.output_track.clear() + self.assistant_playing = False + await self.emit_event( + "response.cancelled", + response_id=response_id, + reason=reason, + ) + + @throttle(1.0, timestamp_kw="timestamp") + async def _emit_audio_chunk_received( + self, + *, + timestamp: float, + chunk_count: int, + sample_count: int, + sample_rate: int, + rms: float, + dc_offset: float, + frame_count: int, + voiced_frame_count: int, + speech_ratio: float, + speaking_before: bool, + speaking_after: bool, + ) -> None: + await self.emit_event( + "input_audio_buffer.chunk_received", + chunk_count=chunk_count, + sample_count=sample_count, + sample_rate=sample_rate, + rms=rms, + dc_offset=dc_offset, + frame_count=frame_count, + voiced_frame_count=voiced_frame_count, + speech_ratio=speech_ratio, + speaking_before=speaking_before, + speaking_after=speaking_after, + ) + + @throttle(1.0, timestamp_kw="timestamp") + async def _emit_video_frame_received( + self, + *, + timestamp: float, + frame_count: int, + buffered_frames: int, + width: int, + height: int, + ) -> None: + await self.emit_event( + "input_video_buffer.frame_received", + frame_count=frame_count, + buffered_frames=buffered_frames, + width=width, + height=height, + ) + + def sample_recent_video_clip( + self, + *, + anchor_ts: float | None, + ) -> tuple[torch.Tensor | None, float | None]: + if anchor_ts is None: + return None, None + + window_start = anchor_ts - self.video.config.clip_window_s + frames = [ + item.frame_rgb + for item in self.video.frames + if window_start <= item.ts_monotonic <= anchor_ts + ] + if not frames: + return None, None + + clip = torch.stack( + [ + torch.from_numpy(np.array(frame, copy=True)).permute(2, 0, 1) + for frame in frames + ], + dim=0, + ) + fps = len(frames) / max(self.video.config.clip_window_s, 1e-6) + return clip, fps + + async def _run_response(self, pending: PendingTurn) -> None: + async with self._response_lock: + self.assistant_playing = True + finish_reason = "stop" + assistant_text_parts: list[str] = [] + + await self.output_track.clear() + clip, fps = self.sample_recent_video_clip(anchor_ts=pending.speech_end_ts) + video_frame_count = int(clip.shape[0]) if clip is not None else 0 + await self.emit_event( + "turn.prepared", + audio_sample_count=( + int(pending.audio.size) if pending.audio is not None else 0 + ), + audio_sample_rate=( + int(pending.sample_rate) + if pending.sample_rate is not None + else None + ), + video_frame_count=video_frame_count, + video_fps=float(fps) if fps is not None else None, + ) + turn = TurnContext( + session_id=self.session_id, + history=list(self.history), + instructions=self.instructions, + user_text=pending.user_text, + user_audio=pending.audio, + user_audio_sample_rate=pending.sample_rate, + recent_video=clip, + recent_video_fps=fps, + turn_index=pending.turn_index, + ) + + try: + async for event in self.backend.stream_response(turn): + if event.type == "response_started": + self.active_response_id = event.response_id + await self.emit_event( + "response.created", + response_id=event.response_id, + model=self.backend.model_name, + ) + continue + + if event.type == "text_delta" and event.text: + response_id = event.response_id + assistant_text_parts.append(event.text) + await self.emit_event( + "response.output_text.delta", + response_id=response_id, + delta=event.text, + ) + continue + + if event.type == "audio_chunk" and event.audio is not None: + response_id = event.response_id + sample_rate = ( + event.sample_rate or self.config.input_audio_sample_rate + ) + audio_np = np.asarray(event.audio) + await self.output_track.enqueue(audio_np, sample_rate) + await self.emit_event( + "response.output_audio.delta", + response_id=response_id, + sample_rate=sample_rate, + sample_count=int(audio_np.size), + ) + continue + + if event.type == "done": + finish_reason = event.finish_reason or "stop" + continue + + if event.type == "error": + await self.emit_event( + "error", + error={ + "message": event.error or "Unknown backend error", + "response_id": event.response_id, + }, + ) + return + except asyncio.CancelledError: + raise + except Exception as exc: + await self.emit_event( + "error", + error={ + "message": str(exc), + "response_id": self.active_response_id, + }, + ) + else: + assistant_text = "".join(assistant_text_parts).strip() + if pending.user_text: + self.history.append({"role": "user", "content": pending.user_text}) + if assistant_text: + self.history.append( + {"role": "assistant", "content": assistant_text} + ) + await self.emit_event( + "response.done", + response_id=self.active_response_id, + finish_reason=finish_reason, + text=assistant_text, + ) + finally: + drain_deadline = time.monotonic() + 10.0 + while ( + getattr(self.output_track, "pending_samples", 0) > 0 + and time.monotonic() < drain_deadline + ): + await asyncio.sleep(0.05) + self.assistant_playing = False + self.active_response_id = None + self.active_task = None + queued_pending = self._queued_pending_turn + self._queued_pending_turn = None + if queued_pending is not None and not self._closed: + self.assistant_playing = True + self.active_task = asyncio.create_task( + self._run_response(queued_pending) + ) diff --git a/sglang_omni/realtime/utils.py b/sglang_omni/realtime/utils.py new file mode 100644 index 000000000..dccbd2520 --- /dev/null +++ b/sglang_omni/realtime/utils.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Small realtime utilities.""" + +from __future__ import annotations + +import functools +import inspect +import time +from typing import Any, Callable, TypeVar + +F = TypeVar("F", bound=Callable[..., Any]) + + +def throttle( + interval_s: float, + *, + timestamp_kw: str | None = None, + state_attr: str = "_throttle_state", +) -> Callable[[F], F]: + """Throttle instance method calls to at most once per interval. + + The decorated method must be called on an object instance. Per-instance + throttle state is stored on ``state_attr`` as a dictionary keyed by the + wrapped method name. + """ + + def decorator(func: F) -> F: + key = func.__qualname__ + is_async = inspect.iscoroutinefunction(func) + + def resolve_timestamp(args: tuple[Any, ...], kwargs: dict[str, Any]) -> float: + if timestamp_kw is not None: + value = kwargs.get(timestamp_kw) + if value is not None: + return float(value) + return time.monotonic() + + def should_run(instance: Any, ts: float) -> bool: + state = getattr(instance, state_attr, None) + if state is None: + state = {} + setattr(instance, state_attr, state) + last_ts = state.get(key) + if last_ts is not None and (ts - float(last_ts)) < interval_s: + return False + state[key] = ts + return True + + if is_async: + + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + if not args: + raise TypeError( + "throttle-decorated methods must be bound to an instance" + ) + ts = resolve_timestamp(args, kwargs) + if not should_run(args[0], ts): + return None + return await func(*args, **kwargs) + + return async_wrapper # type: ignore[return-value] + + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + if not args: + raise TypeError( + "throttle-decorated methods must be bound to an instance" + ) + ts = resolve_timestamp(args, kwargs) + if not should_run(args[0], ts): + return None + return func(*args, **kwargs) + + return sync_wrapper # type: ignore[return-value] + + return decorator diff --git a/sglang_omni/realtime/vad.py b/sglang_omni/realtime/vad.py new file mode 100644 index 000000000..2ef724322 --- /dev/null +++ b/sglang_omni/realtime/vad.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 +"""webrtcvad-based voice activity detection for realtime sessions.""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from math import ceil + +import numpy as np + +try: + import webrtcvad +except ImportError: # pragma: no cover - surfaced at runtime + webrtcvad = None + + +@dataclass +class VadConfig: + sample_rate: int = 16000 + aggressiveness: int = 3 + frame_duration_ms: int = 20 + min_speech_s: float = 0.25 + min_silence_s: float = 0.60 + preroll_s: float = 0.18 + # Legacy fields kept for request compatibility; the VAD backend does not use them. + start_threshold: float = 0.020 + stop_threshold: float = 0.012 + start_margin: float = 0.020 + stop_margin: float = 0.005 + bootstrap_s: float = 0.50 + noise_floor_alpha: float = 0.20 + + +@dataclass +class VadEvent: + speech_started: bool = False + speech_stopped: bool = False + + +class EnergyVad: + """Stateful wrapper around webrtcvad with start/stop hysteresis.""" + + def __init__(self, config: VadConfig | None = None) -> None: + self.config = config or VadConfig() + if webrtcvad is None: + raise RuntimeError( + "Realtime VAD now depends on webrtcvad. " + "Install the project with the realtime extra." + ) + + if self.config.sample_rate not in {8000, 16000, 32000, 48000}: + raise ValueError( + f"Unsupported VAD sample rate {self.config.sample_rate}; " + "expected one of 8000, 16000, 32000, 48000." + ) + if self.config.frame_duration_ms not in {10, 20, 30}: + raise ValueError( + f"Unsupported VAD frame duration {self.config.frame_duration_ms} ms; " + "expected one of 10, 20, 30." + ) + + self._vad = webrtcvad.Vad(int(np.clip(self.config.aggressiveness, 0, 3))) + self._frame_samples = ( + self.config.sample_rate * self.config.frame_duration_ms + ) // 1000 + self._frame_duration_s = self.config.frame_duration_ms / 1000.0 + self._start_window_frames = max( + 1, + int(round(self.config.min_speech_s / self._frame_duration_s)), + ) + self._stop_window_frames = max( + 1, + int(round(self.config.min_silence_s / self._frame_duration_s)), + ) + self._start_required_frames = max( + 1, + ceil(self._start_window_frames * 0.6), + ) + self._stop_required_unvoiced_frames = max( + 1, + ceil(self._stop_window_frames * 0.5), + ) + + self.speaking = False + self._frame_tail = np.zeros(0, dtype=np.int16) + self._recent_votes: deque[bool] = deque() + self._last_frame_count = 0 + self._last_voiced_frame_count = 0 + + @staticmethod + def measure_level(audio: np.ndarray) -> float: + audio = np.asarray(audio, dtype=np.float32).reshape(-1) + if audio.size == 0: + return 0.0 + centered = audio - float(np.mean(audio)) + return float(np.sqrt(np.mean(np.square(centered)))) + + @property + def noise_floor(self) -> float: + return 0.0 + + def effective_start_threshold(self) -> float: + return 0.0 + + def effective_stop_threshold(self) -> float: + return 0.0 + + @property + def last_frame_count(self) -> int: + return self._last_frame_count + + @property + def last_voiced_frame_count(self) -> int: + return self._last_voiced_frame_count + + @property + def last_speech_ratio(self) -> float: + if self._last_frame_count <= 0: + return 0.0 + return float(self._last_voiced_frame_count / self._last_frame_count) + + def _detect_frame(self, pcm_frame: np.ndarray) -> bool: + return bool(self._vad.is_speech(pcm_frame.tobytes(), self.config.sample_rate)) + + def _append_vote(self, vote: bool, *, speaking: bool) -> None: + window = self._stop_window_frames if speaking else self._start_window_frames + self._recent_votes.append(vote) + while len(self._recent_votes) > window: + self._recent_votes.popleft() + + def reset(self) -> None: + self.speaking = False + self._frame_tail = np.zeros(0, dtype=np.int16) + self._recent_votes.clear() + self._last_frame_count = 0 + self._last_voiced_frame_count = 0 + + def process(self, audio: np.ndarray) -> VadEvent: + audio = np.asarray(audio, dtype=np.float32).reshape(-1) + if audio.size == 0: + self._last_frame_count = 0 + self._last_voiced_frame_count = 0 + return VadEvent() + + pcm = np.clip(audio, -1.0, 1.0) + pcm = (pcm * 32767.0).astype(np.int16, copy=False) + if self._frame_tail.size: + pcm = np.concatenate([self._frame_tail, pcm]) + + total_frames = int(pcm.size // self._frame_samples) + if total_frames <= 0: + self._frame_tail = pcm + self._last_frame_count = 0 + self._last_voiced_frame_count = 0 + return VadEvent() + + event = VadEvent() + voiced_frames = 0 + + for index in range(total_frames): + start = index * self._frame_samples + frame = pcm[start : start + self._frame_samples] + is_voiced = self._detect_frame(frame) + voiced_frames += int(is_voiced) + self._append_vote(is_voiced, speaking=self.speaking) + + if not self.speaking: + if ( + len(self._recent_votes) >= self._start_window_frames + and sum(self._recent_votes) >= self._start_required_frames + ): + self.speaking = True + self._recent_votes.clear() + event.speech_started = True + continue + + if self.speaking: + unvoiced_frames = len(self._recent_votes) - sum(self._recent_votes) + if ( + len(self._recent_votes) >= self._stop_window_frames + and unvoiced_frames >= self._stop_required_unvoiced_frames + ): + self.speaking = False + self._recent_votes.clear() + event.speech_stopped = True + + consumed = total_frames * self._frame_samples + self._frame_tail = pcm[consumed:] + self._last_frame_count = total_frames + self._last_voiced_frame_count = voiced_frames + return event diff --git a/sglang_omni/serve/openai_api.py b/sglang_omni/serve/openai_api.py index 81428ce9e..df09896e3 100644 --- a/sglang_omni/serve/openai_api.py +++ b/sglang_omni/serve/openai_api.py @@ -90,6 +90,14 @@ def create_app( _register_models(app) _register_chat_completions(app) _register_speech(app) + try: + from sglang_omni.serve.realtime_ws_api import create_realtime_ws_router + + app.include_router( + create_realtime_ws_router(client, model_name=app.state.model_name) + ) + except Exception: + logger.exception("Failed to register realtime routes") return app diff --git a/sglang_omni/serve/realtime_ws_api.py b/sglang_omni/serve/realtime_ws_api.py new file mode 100644 index 000000000..6dfbdc515 --- /dev/null +++ b/sglang_omni/serve/realtime_ws_api.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +"""WebSocket transport for realtime sessions.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from collections.abc import Callable +from dataclasses import dataclass + +import numpy as np +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from pydantic import BaseModel + +from sglang_omni.client import Client +from sglang_omni.realtime.backend import OmniResponseBackend, ResponseBackend +from sglang_omni.realtime.media import mono_float32, resample_linear +from sglang_omni.realtime.session import RealtimeSession, RealtimeSessionConfig +from sglang_omni.realtime.vad import VadConfig + +logger = logging.getLogger(__name__) + +BackendFactory = Callable[[str, int], ResponseBackend] + + +class RealtimeVadRequest(BaseModel): + aggressiveness: int = 3 + frame_duration_ms: int = 20 + min_speech_s: float = 0.25 + min_silence_s: float = 0.60 + preroll_s: float = 0.18 + # Legacy energy-VAD fields retained for compatibility with older clients. + start_threshold: float = 0.020 + stop_threshold: float = 0.012 + + +class WebSocketEventChannel: + """Queue-backed event channel with the subset RealtimeSession expects.""" + + def __init__(self, send_text: Callable[[str], None]) -> None: + self._send_text = send_text + self.readyState = "open" + + def send(self, message: str) -> None: + if self.readyState != "open": + raise RuntimeError("WebSocket event channel is closed") + self._send_text(message) + + def close(self) -> None: + self.readyState = "closed" + + +class WebSocketAudioOutputSink: + """PCM sink that streams assistant audio back over the websocket.""" + + def __init__( + self, + *, + sample_rate: int, + send_bytes: Callable[[bytes], None], + send_text: Callable[[str], None], + session_id: str, + ) -> None: + self.sample_rate = int(sample_rate) + self._send_bytes = send_bytes + self._send_text = send_text + self._session_id = session_id + self._pending_samples = 0 + self._pending_updated_at: float | None = None + + @property + def pending_samples(self) -> int: + return self._drain_pending() + + def _drain_pending(self) -> int: + now = time.monotonic() + if self._pending_updated_at is None: + return 0 + elapsed = max(now - self._pending_updated_at, 0.0) + drained = int(round(elapsed * self.sample_rate)) + if drained <= 0: + return int(self._pending_samples) + self._pending_samples = max(int(self._pending_samples) - drained, 0) + self._pending_updated_at = now if self._pending_samples > 0 else None + return int(self._pending_samples) + + async def clear(self) -> None: + self._pending_samples = 0 + self._pending_updated_at = None + self._send_text( + json.dumps( + { + "type": "output_audio_buffer.cleared", + "session_id": self._session_id, + } + ) + ) + + async def enqueue(self, audio: np.ndarray, sample_rate: int) -> None: + pcm = mono_float32(audio) + pcm = resample_linear(pcm, sample_rate, self.sample_rate) + if pcm.size == 0: + return + + pcm_i16 = np.clip(pcm * 32767.0, -32768.0, 32767.0).astype(" APIRouter: + if backend_factory is None: + if client is None: + raise ValueError( + "create_realtime_ws_router requires either client or backend_factory" + ) + + def backend_factory( + resolved_model: str, + max_new_tokens: int, + ) -> ResponseBackend: + return OmniResponseBackend( + client=client, + model=resolved_model, + max_new_tokens=max_new_tokens, + output_modalities=("text", "audio"), + ) + + router = APIRouter() + + @router.websocket("/v1/realtime/ws") + async def realtime_ws(websocket: WebSocket) -> None: + await websocket.accept() + + send_queue: asyncio.Queue[tuple[str, str | bytes]] = asyncio.Queue() + session_id = uuid.uuid4().hex + + def send_text_nowait(payload: str) -> None: + send_queue.put_nowait(("text", payload)) + + def send_bytes_nowait(payload: bytes) -> None: + send_queue.put_nowait(("bytes", payload)) + + async def sender() -> None: + while True: + kind, payload = await send_queue.get() + if kind == "text": + await websocket.send_text(str(payload)) + else: + await websocket.send_bytes(bytes(payload)) + + sender_task = asyncio.create_task(sender()) + + model = websocket.query_params.get("model") or model_name + instructions = websocket.query_params.get("instructions") + input_audio_mode = websocket.query_params.get("input_audio_mode") or "vad" + max_new_tokens_raw = websocket.query_params.get("max_new_tokens") or "256" + try: + max_new_tokens = max(int(max_new_tokens_raw), 1) + except ValueError: + max_new_tokens = 256 + + vad_config = VadConfig() + vad_raw = websocket.query_params.get("vad") + if vad_raw: + try: + vad = RealtimeVadRequest.model_validate_json(vad_raw) + except Exception: + logger.warning("Ignoring invalid websocket VAD config: %s", vad_raw) + else: + vad_config = VadConfig( + sample_rate=vad_config.sample_rate, + aggressiveness=vad.aggressiveness, + frame_duration_ms=vad.frame_duration_ms, + min_speech_s=vad.min_speech_s, + min_silence_s=vad.min_silence_s, + preroll_s=vad.preroll_s, + start_threshold=vad.start_threshold, + stop_threshold=vad.stop_threshold, + ) + + output_sink = WebSocketAudioOutputSink( + sample_rate=24000, + send_bytes=send_bytes_nowait, + send_text=send_text_nowait, + session_id=session_id, + ) + event_channel = WebSocketEventChannel(send_text_nowait) + backend = backend_factory(model, max_new_tokens) + session = RealtimeSession( + session_id=session_id, + backend=backend, + output_track=output_sink, + config=RealtimeSessionConfig( + instructions=( + instructions + or "You are a concise, natural voice assistant. Answer conversationally." + ), + input_audio_mode=input_audio_mode, + vad=vad_config, + ), + ) + handle = WebSocketSessionHandle( + session=session, + event_channel=event_channel, + output_sink=output_sink, + ) + session.attach_event_channel(event_channel) + + await session.emit_event( + "session.created", + model=session.backend.model_name, + instructions=session.instructions, + audio={ + "input_mode": session.turn_mode, + "input_encoding": "pcm16le", + "output_encoding": "pcm16le", + "output_sample_rate": output_sink.sample_rate, + }, + transport={"type": "websocket"}, + ) + + try: + while True: + message = await websocket.receive() + message_type = message.get("type") + if message_type == "websocket.disconnect": + break + + payload_text = message.get("text") + if payload_text is not None: + try: + payload = json.loads(payload_text) + except json.JSONDecodeError: + await session.emit_event( + "error", + error={ + "message": "Invalid JSON control message", + "session_id": session.session_id, + }, + ) + continue + + event_type = str(payload.get("type") or "").strip() + if event_type == "input_audio_format": + sample_rate = payload.get("sample_rate") + if isinstance(sample_rate, int) and sample_rate > 0: + handle.input_sample_rate = int(sample_rate) + await session.emit_event( + "input_audio_format.updated", + sample_rate=handle.input_sample_rate, + encoding="pcm16le", + ) + continue + + await session.handle_client_event(payload) + continue + + payload_bytes = message.get("bytes") + if payload_bytes is None: + continue + if not payload_bytes: + continue + + if len(payload_bytes) % 2 != 0: + payload_bytes = payload_bytes[:-1] + if not payload_bytes: + continue + + audio = np.frombuffer(payload_bytes, dtype=" None: + self.submissions: list[tuple[str, str, object]] = [] + self.abort_messages: list[object] = [] + + async def submit_to_stage(self, stage_name: str, endpoint: str, msg: object) -> None: + self.submissions.append((stage_name, endpoint, msg)) + + async def broadcast_abort(self, msg: object) -> None: + self.abort_messages.append(msg) + + +async def _wait_for_stream_registration(coordinator: Coordinator, request_id: str) -> None: + for _ in range(100): + if request_id in coordinator._stream_queues: + return + await asyncio.sleep(0) + raise AssertionError(f"stream queue for {request_id} was not registered") + + +async def _wait_for_abort(control_plane: _FakeControlPlane, request_id: str) -> None: + for _ in range(100): + if [msg.request_id for msg in control_plane.abort_messages] == [request_id]: + return + await asyncio.sleep(0) + raise AssertionError(f"abort for {request_id} was not observed") + + +def _make_coordinator() -> tuple[Coordinator, _FakeControlPlane]: + coordinator = Coordinator( + completion_endpoint="inproc://completion", + abort_endpoint="inproc://abort", + entry_stage="entry", + ) + control_plane = _FakeControlPlane() + coordinator.control_plane = control_plane + coordinator.register_stage("entry", "inproc://entry") + return coordinator, control_plane + + +@pytest.mark.asyncio +async def test_stream_abort_on_early_consumer_exit() -> None: + coordinator, control_plane = _make_coordinator() + request_id = "req-early-exit" + received: list[object] = [] + + async def _consume_one() -> None: + async for msg in coordinator.stream(request_id, OmniRequest(inputs={"text": "hi"})): + received.append(msg) + break + + task = asyncio.create_task(_consume_one()) + await _wait_for_stream_registration(coordinator, request_id) + await coordinator._handle_stream( + StreamMessage( + request_id=request_id, + from_stage="decode", + chunk={"text": "hello"}, + modality="text", + ) + ) + await task + await _wait_for_abort(control_plane, request_id) + + assert len(received) == 1 + assert [msg.request_id for msg in control_plane.abort_messages] == [request_id] + assert request_id not in coordinator._requests + assert request_id not in coordinator._stream_queues + assert request_id not in coordinator._completion_futures + + +@pytest.mark.asyncio +async def test_stream_does_not_abort_after_normal_completion() -> None: + coordinator, control_plane = _make_coordinator() + request_id = "req-complete" + + async def _consume_all() -> list[object]: + return [ + msg + async for msg in coordinator.stream( + request_id, + OmniRequest(inputs={"text": "hi"}), + ) + ] + + task = asyncio.create_task(_consume_all()) + await _wait_for_stream_registration(coordinator, request_id) + await coordinator._handle_stream( + StreamMessage( + request_id=request_id, + from_stage="decode", + chunk={"text": "hello"}, + modality="text", + ) + ) + await coordinator._handle_completion( + CompleteMessage( + request_id=request_id, + from_stage="decode", + success=True, + result={"text": "hello"}, + ) + ) + + received = await task + + assert len(received) == 2 + assert control_plane.abort_messages == [] + assert request_id not in coordinator._requests + assert request_id not in coordinator._stream_queues + assert request_id not in coordinator._completion_futures \ No newline at end of file diff --git a/tests/test_realtime_audio_pipeline.py b/tests/test_realtime_audio_pipeline.py new file mode 100644 index 000000000..a8d60bf6c --- /dev/null +++ b/tests/test_realtime_audio_pipeline.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf + +from sglang_omni.realtime.backend import MockResponseBackend +from sglang_omni.realtime.media import mono_float32, resample_linear +from sglang_omni.realtime.session import RealtimeSession, RealtimeSessionConfig + +TEST_AUDIO_PATH = Path(__file__).resolve().parent / "data" / "query_to_cars.wav" + + +class _CollectingOutputTrack: + def __init__(self) -> None: + self.pending_samples = 0 + self.clear_calls = 0 + self.enqueue_calls: list[tuple[np.ndarray, int]] = [] + + async def clear(self) -> None: + self.clear_calls += 1 + self.pending_samples = 0 + + async def enqueue(self, audio: np.ndarray, sample_rate: int) -> None: + self.enqueue_calls.append((np.asarray(audio), sample_rate)) + self.pending_samples = 0 + + +def _load_test_audio() -> tuple[np.ndarray, int]: + audio, sample_rate = sf.read(TEST_AUDIO_PATH, dtype="int16") + return np.asarray(audio), int(sample_rate) + + +def _iter_audio_chunks(audio_i16: np.ndarray, sample_rate: int) -> list[np.ndarray]: + frame_samples = sample_rate // 50 + chunks: list[np.ndarray] = [] + for start in range(0, audio_i16.shape[0], frame_samples): + chunk = audio_i16[start : start + frame_samples] + if chunk.size == 0: + continue + chunks.append(chunk) + return chunks + + +def _expected_session_audio(audio_i16: np.ndarray, sample_rate: int) -> np.ndarray: + frame_samples = sample_rate // 50 + chunks: list[np.ndarray] = [] + for start in range(0, audio_i16.shape[0], frame_samples): + chunk = audio_i16[start : start + frame_samples] + if chunk.size == 0: + continue + mono = mono_float32(chunk) + chunks.append(resample_linear(mono, sample_rate, 16000)) + return np.concatenate(chunks) + + +@pytest.mark.asyncio +async def test_real_wav_audio_pipeline_round_trips_through_manual_mock_echo(): + audio_i16, sample_rate = _load_test_audio() + backend = MockResponseBackend( + audio_mode="echo", + output_modalities=("audio",), + inter_chunk_delay_s=0.0, + chunk_duration_s=0.1, + ) + output_track = _CollectingOutputTrack() + session = RealtimeSession( + session_id="session-audio-pipeline", + backend=backend, + output_track=output_track, + config=RealtimeSessionConfig(), + ) + + await session.handle_client_event({"type": "input_audio_buffer.start"}) + for chunk in _iter_audio_chunks(audio_i16, sample_rate): + await session.handle_audio_chunk(chunk, sample_rate) + await session.handle_client_event({"type": "input_audio_buffer.commit"}) + + assert session.active_task is not None + await asyncio.wait_for(session.active_task, timeout=2.0) + + assert output_track.enqueue_calls + assert {call_sample_rate for _, call_sample_rate in output_track.enqueue_calls} == { + 16000 + } + + echoed_audio = np.concatenate([audio for audio, _ in output_track.enqueue_calls]) + expected_user_audio = _expected_session_audio(audio_i16, sample_rate) + expected_echo = backend._condition_echo_waveform(expected_user_audio) + + np.testing.assert_allclose(echoed_audio, expected_echo, atol=1e-5) + assert np.max(np.abs(echoed_audio)) <= 0.35 + 1e-5 + assert np.sqrt(np.mean(np.square(echoed_audio))) > 0.01 diff --git a/tests/test_realtime_backend_mock.py b/tests/test_realtime_backend_mock.py new file mode 100644 index 000000000..dd86fd9d7 --- /dev/null +++ b/tests/test_realtime_backend_mock.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +import numpy as np +import pytest +import soundfile as sf + +from sglang_omni.realtime.backend import MockResponseBackend, TurnContext + + +@pytest.mark.asyncio +async def test_mock_response_backend_playback_mode_replays_user_audio_verbatim(): + backend = MockResponseBackend( + response_text="Mock backend replayed the captured utterance.", + audio_mode="playback", + inter_chunk_delay_s=0.0, + total_duration_s=0.2, + chunk_duration_s=0.1, + ) + user_audio = np.linspace(-0.25, 0.25, 32, dtype=np.float32) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=user_audio, + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + audio_events = [event for event in events if event.type == "audio_chunk"] + + assert events[0].type == "response_started" + assert any(event.type == "text_delta" for event in events) + assert audio_events + assert [event.sample_rate for event in audio_events] == [16000] + np.testing.assert_allclose( + np.concatenate([event.audio for event in audio_events]), + user_audio, + ) + assert events[-1].type == "done" + assert events[-1].finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_mock_response_backend_echo_mode_conditions_user_audio(): + backend = MockResponseBackend( + audio_mode="echo", + inter_chunk_delay_s=0.0, + total_duration_s=0.2, + chunk_duration_s=0.1, + ) + user_audio = np.linspace(-0.8, 0.9, 32, dtype=np.float32) + 0.2 + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=user_audio, + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + audio_events = [event for event in events if event.type == "audio_chunk"] + + assert audio_events + echoed_audio = np.concatenate([event.audio for event in audio_events]) + np.testing.assert_allclose( + echoed_audio, + backend._condition_echo_waveform(user_audio), + ) + + +@pytest.mark.asyncio +async def test_mock_response_backend_falls_back_to_tone_without_audio(): + backend = MockResponseBackend( + audio_mode="playback", + inter_chunk_delay_s=0.0, + total_duration_s=0.05, + chunk_duration_s=0.05, + sample_rate=24000, + ) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=None, + user_audio_sample_rate=None, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + audio_events = [event for event in events if event.type == "audio_chunk"] + + assert audio_events + assert audio_events[0].sample_rate == 24000 + assert np.any(audio_events[0].audio != 0.0) + + +@pytest.mark.asyncio +async def test_mock_response_backend_tone_mode_ignores_user_audio(): + backend = MockResponseBackend( + audio_mode="tone", + inter_chunk_delay_s=0.0, + total_duration_s=0.05, + chunk_duration_s=0.05, + sample_rate=24000, + ) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=np.linspace(-0.9, 0.9, 128, dtype=np.float32), + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + audio_events = [event for event in events if event.type == "audio_chunk"] + tone_audio = np.concatenate([event.audio for event in audio_events]) + + assert audio_events + assert [event.sample_rate for event in audio_events] == [24000] + assert tone_audio.shape != turn.user_audio.shape + assert np.any(tone_audio != 0.0) + + +@pytest.mark.asyncio +async def test_mock_response_backend_cancel_stops_stream(): + backend = MockResponseBackend( + audio_mode="echo", + inter_chunk_delay_s=0.05, + total_duration_s=0.4, + chunk_duration_s=0.1, + ) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=np.zeros(32, dtype=np.float32), + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + events = [] + + async def _collect(): + async for event in backend.stream_response(turn): + events.append(event) + if event.type == "response_started": + await backend.cancel(event.response_id) + + await asyncio.wait_for(_collect(), timeout=1.0) + + assert events[0].type == "response_started" + assert events[-1].type == "done" + assert events[-1].finish_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_mock_response_backend_dumps_captured_wav(tmp_path): + backend = MockResponseBackend( + audio_mode="echo", + dump_audio_dir=str(tmp_path), + inter_chunk_delay_s=0.0, + chunk_duration_s=0.05, + ) + user_audio = np.linspace(-0.2, 0.2, 160, dtype=np.float32) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=user_audio, + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + _events = [event async for event in backend.stream_response(turn)] + + dumped = sorted(tmp_path.glob("*_captured.wav")) + assert len(dumped) == 1 + + dumped_audio, dumped_sr = sf.read(dumped[0], dtype="float32") + assert dumped_sr == 16000 + np.testing.assert_allclose(dumped_audio, user_audio, atol=1e-4) diff --git a/tests/test_realtime_backend_omni.py b/tests/test_realtime_backend_omni.py new file mode 100644 index 000000000..08e37d4c6 --- /dev/null +++ b/tests/test_realtime_backend_omni.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 + +import msgpack +import numpy as np +import pytest + +from sglang_omni.client.types import GenerateChunk +from sglang_omni.realtime.backend import OmniResponseBackend, TurnContext + + +class _FakeClient: + def __init__(self) -> None: + self.cancelled: str | None = None + self.last_request = None + + async def generate(self, request, request_id: str): + self.last_request = request + yield GenerateChunk(request_id=request_id, text="hello") + yield GenerateChunk( + request_id=request_id, + modality="audio", + audio_data=np.array([0.1, -0.1], dtype=np.float32), + sample_rate=24000, + ) + yield GenerateChunk(request_id=request_id, finish_reason="stop") + + async def abort(self, request_id: str) -> None: + self.cancelled = request_id + + +class _SnapshotTextClient: + async def generate(self, request, request_id: str): + del request + yield GenerateChunk(request_id=request_id, text="Hi") + yield GenerateChunk(request_id=request_id, text="Hi there") + yield GenerateChunk(request_id=request_id, text="Hi there") + yield GenerateChunk(request_id=request_id, finish_reason="stop") + + async def abort(self, request_id: str) -> None: + del request_id + + +@pytest.mark.asyncio +async def test_omni_response_backend_normalizes_turn_output(): + client = _FakeClient() + backend = OmniResponseBackend( + client=client, + model="qwen3-omni", + max_new_tokens=32, + output_modalities=("text", "audio"), + ) + turn = TurnContext( + session_id="session-1", + history=[{"role": "assistant", "content": "previous"}], + instructions="be concise", + user_text="hi there", + user_audio=np.zeros(32, dtype=np.float32), + user_audio_sample_rate=16000, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + + assert [event.type for event in events] == [ + "response_started", + "text_delta", + "audio_chunk", + "done", + ] + assert events[1].text == "hello" + assert events[2].sample_rate == 24000 + assert events[3].finish_reason == "stop" + + request = client.last_request + assert request is not None + assert request.model == "qwen3-omni" + assert request.metadata["audio_target_sr"] == 16000 + assert len(request.metadata["audios"]) == 1 + audio_payload = request.metadata["audios"][0] + assert isinstance(audio_payload["audio_waveform"], bytes) + assert audio_payload["audio_waveform_dtype"] == "float32" + assert audio_payload["audio_waveform_shape"] == [32] + msgpack.packb(request.metadata, use_bin_type=True) + assert len(request.messages) == 3 + assert request.messages[-1].content == "hi there" + + await backend.cancel(events[0].response_id) + assert client.cancelled == events[0].response_id + + +@pytest.mark.asyncio +async def test_omni_response_backend_coerces_snapshot_text_to_deltas(): + backend = OmniResponseBackend( + client=_SnapshotTextClient(), + model="qwen3-omni", + output_modalities=("text",), + ) + turn = TurnContext( + session_id="session-1", + history=[], + instructions=None, + user_text="hello", + user_audio=None, + user_audio_sample_rate=None, + recent_video=None, + recent_video_fps=None, + ) + + events = [event async for event in backend.stream_response(turn)] + text_events = [event for event in events if event.type == "text_delta"] + + assert [event.text for event in text_events] == ["Hi", " there"] + assert "".join(event.text for event in text_events) == "Hi there" diff --git a/tests/test_realtime_media.py b/tests/test_realtime_media.py new file mode 100644 index 000000000..cd1b8aef7 --- /dev/null +++ b/tests/test_realtime_media.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import numpy as np + +from sglang_omni.realtime.media import mono_float32 + + +def test_mono_float32_scales_integer_stereo_before_mixdown(): + stereo = np.array( + [[16384, -16384, 8192, -8192], [16384, -16384, 8192, -8192]], + dtype=np.int16, + ) + + mono = mono_float32(stereo) + + np.testing.assert_allclose( + mono, + np.array([0.5, -0.5, 0.25, -0.25], dtype=np.float32), + atol=1e-5, + ) diff --git a/tests/test_realtime_session.py b/tests/test_realtime_session.py new file mode 100644 index 000000000..b9913be37 --- /dev/null +++ b/tests/test_realtime_session.py @@ -0,0 +1,478 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import json + +import numpy as np +import pytest + +from sglang_omni.realtime.backend import BackendCapabilities, ResponseEvent +from sglang_omni.realtime.session import RealtimeSession, RealtimeSessionConfig +from sglang_omni.realtime.vad import VadConfig + + +class _FakeChannel: + readyState = "open" + + def __init__(self) -> None: + self.messages: list[dict] = [] + + def send(self, raw: str) -> None: + self.messages.append(json.loads(raw)) + + +class _FakeOutputTrack: + def __init__(self) -> None: + self.pending_samples = 0 + self.clear_calls = 0 + self.enqueue_calls: list[tuple[np.ndarray, int]] = [] + + async def clear(self) -> None: + self.clear_calls += 1 + self.pending_samples = 0 + + async def enqueue(self, audio: np.ndarray, sample_rate: int) -> None: + self.enqueue_calls.append((np.asarray(audio), sample_rate)) + self.pending_samples = 0 + + +class _ScriptedBackend: + model_name = "fake-model" + capabilities = BackendCapabilities( + accepts_audio_input=True, + accepts_video_input=True, + returns_text=True, + returns_audio=True, + supports_cancel=True, + ) + + def __init__(self) -> None: + self.turns = [] + self.cancelled: list[str] = [] + + async def stream_response(self, turn): + self.turns.append(turn) + response_idx = len(self.turns) + response_id = f"resp-{response_idx}" + yield ResponseEvent(type="response_started", response_id=response_id) + yield ResponseEvent( + type="text_delta", + response_id=response_id, + text=f"answer-{response_idx}", + ) + yield ResponseEvent( + type="audio_chunk", + response_id=response_id, + audio=np.array([0.1, -0.1], dtype=np.float32), + sample_rate=24000, + ) + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason="stop", + ) + + async def cancel(self, response_id: str) -> None: + self.cancelled.append(response_id) + + +class _BlockingBackend: + model_name = "fake-model" + capabilities = BackendCapabilities( + accepts_audio_input=True, + returns_text=True, + returns_audio=False, + supports_cancel=True, + ) + + def __init__(self) -> None: + self.turns = [] + self.cancelled: list[str] = [] + self.started = asyncio.Event() + self.released = asyncio.Event() + + async def stream_response(self, turn): + self.turns.append(turn) + response_id = "resp-cancel" + yield ResponseEvent(type="response_started", response_id=response_id) + self.started.set() + await self.released.wait() + yield ResponseEvent( + type="done", + response_id=response_id, + finish_reason="cancelled", + ) + + async def cancel(self, response_id: str) -> None: + self.cancelled.append(response_id) + self.released.set() + + +class _FakeVadEvent: + def __init__( + self, *, speech_started: bool = False, speech_stopped: bool = False + ) -> None: + self.speech_started = speech_started + self.speech_stopped = speech_stopped + + +class _FakeVad: + def __init__(self) -> None: + self.speaking = False + self.last_frame_count = 2 + self.last_voiced_frame_count = 0 + self.last_speech_ratio = 0.0 + self._call_count = 0 + + def measure_level(self, audio: np.ndarray) -> float: + audio = np.asarray(audio, dtype=np.float32).reshape(-1) + if audio.size == 0: + return 0.0 + return float(np.sqrt(np.mean(np.square(audio)))) + + def process(self, _audio: np.ndarray) -> _FakeVadEvent: + self._call_count += 1 + phase = ((self._call_count - 1) % 3) + 1 + if phase == 1: + self.last_voiced_frame_count = 2 + self.last_speech_ratio = 1.0 + return _FakeVadEvent() + if phase == 2: + self.speaking = True + self.last_voiced_frame_count = 2 + self.last_speech_ratio = 1.0 + return _FakeVadEvent(speech_started=True) + + self.speaking = False + self.last_voiced_frame_count = 0 + self.last_speech_ratio = 0.0 + return _FakeVadEvent(speech_stopped=True) + + def reset(self) -> None: + self.speaking = False + self.last_frame_count = 0 + self.last_voiced_frame_count = 0 + self.last_speech_ratio = 0.0 + self._call_count = 0 + + +def _make_session( + backend, +) -> tuple[RealtimeSession, _FakeOutputTrack, _FakeChannel]: + output_track = _FakeOutputTrack() + channel = _FakeChannel() + session = RealtimeSession( + session_id="session-1", + backend=backend, + output_track=output_track, + config=RealtimeSessionConfig( + vad=VadConfig( + start_threshold=0.02, + stop_threshold=0.01, + min_speech_s=0.1, + min_silence_s=0.1, + preroll_s=0.0, + ), + ), + ) + session.vad = _FakeVad() + session.attach_event_channel(channel) + return session, output_track, channel + + +async def _drive_turn( + session: RealtimeSession, + *, + user_text: str, + start_ts: float, +) -> None: + await session.handle_client_event( + { + "type": "conversation.item.create", + "item": {"role": "user", "content": user_text}, + } + ) + + speech = np.full(1600, 0.2, dtype=np.float32) + silence = np.zeros(1600, dtype=np.float32) + + await session.handle_audio_chunk(speech, 16000, timestamp=start_ts) + await session.handle_audio_chunk(speech, 16000, timestamp=start_ts + 0.1) + await session.handle_audio_chunk(silence, 16000, timestamp=start_ts + 0.2) + + task = session.active_task + assert task is not None + await asyncio.wait_for(task, timeout=1.0) + + +async def _drive_text_turn( + session: RealtimeSession, + *, + user_text: str, +) -> None: + await session.handle_client_event( + { + "type": "conversation.item.create", + "item": {"role": "user", "content": user_text}, + } + ) + await session.handle_client_event({"type": "response.create"}) + + task = session.active_task + assert task is not None + await asyncio.wait_for(task, timeout=1.0) + + +@pytest.mark.asyncio +async def test_realtime_session_runs_turns_with_fake_backend_and_history(): + backend = _ScriptedBackend() + session, output_track, channel = _make_session(backend) + + frame = np.zeros((8, 8, 3), dtype=np.uint8) + await session.handle_video_frame(frame, timestamp=1.0) + await session.handle_video_frame(frame, timestamp=1.6) + + await _drive_turn(session, user_text="describe this", start_ts=2.0) + await _drive_turn(session, user_text="follow up", start_ts=4.0) + + assert len(backend.turns) == 2 + assert backend.turns[0].user_text == "describe this" + assert backend.turns[0].recent_video is not None + assert backend.turns[0].recent_video.shape[0] == 2 + assert backend.turns[1].history == [ + {"role": "user", "content": "describe this"}, + {"role": "assistant", "content": "answer-1"}, + ] + + assert session.history == [ + {"role": "user", "content": "describe this"}, + {"role": "assistant", "content": "answer-1"}, + {"role": "user", "content": "follow up"}, + {"role": "assistant", "content": "answer-2"}, + ] + + assert len(output_track.enqueue_calls) == 2 + assert output_track.enqueue_calls[0][1] == 24000 + assert output_track.clear_calls >= 2 + + event_types = [event["type"] for event in channel.messages] + assert event_types.count("conversation.item.created") == 2 + assert event_types.count("input_audio_buffer.chunk_received") == 2 + assert event_types.count("input_audio_buffer.speech_started") == 2 + assert event_types.count("input_audio_buffer.speech_stopped") == 2 + assert event_types.count("input_video_buffer.frame_received") == 1 + assert event_types.count("turn.prepared") == 2 + assert event_types.count("response.created") == 2 + assert event_types.count("response.output_text.delta") == 2 + assert event_types.count("response.output_audio.delta") == 2 + assert event_types.count("response.done") == 2 + + frame_event = next( + event + for event in channel.messages + if event["type"] == "input_video_buffer.frame_received" + ) + assert frame_event["frame_count"] == 1 + assert frame_event["buffered_frames"] == 1 + + audio_chunk_events = [ + event + for event in channel.messages + if event["type"] == "input_audio_buffer.chunk_received" + ] + assert all(event["sample_count"] > 0 for event in audio_chunk_events) + assert all(event["sample_rate"] == 16000 for event in audio_chunk_events) + assert all(event["rms"] >= 0.0 for event in audio_chunk_events) + assert all("dc_offset" in event for event in audio_chunk_events) + assert all("frame_count" in event for event in audio_chunk_events) + assert all("voiced_frame_count" in event for event in audio_chunk_events) + assert all("speech_ratio" in event for event in audio_chunk_events) + assert all("speaking_before" in event for event in audio_chunk_events) + assert all("speaking_after" in event for event in audio_chunk_events) + assert [event["chunk_count"] for event in audio_chunk_events] == [1, 4] + + turn_events = [ + event for event in channel.messages if event["type"] == "turn.prepared" + ] + assert all(event["audio_sample_count"] > 0 for event in turn_events) + assert turn_events[0]["video_frame_count"] == 2 + assert turn_events[0]["video_fps"] is not None + + +@pytest.mark.asyncio +async def test_realtime_session_cancel_delegates_to_backend(): + backend = _BlockingBackend() + session, output_track, channel = _make_session(backend) + + drive_task = asyncio.create_task( + _drive_turn(session, user_text="cancel this", start_ts=2.0) + ) + await asyncio.wait_for(backend.started.wait(), timeout=1.0) + + await session.handle_client_event({"type": "response.cancel"}) + await asyncio.wait_for(drive_task, timeout=1.0) + + assert backend.cancelled == ["resp-cancel"] + assert output_track.clear_calls >= 2 + assert any(event["type"] == "response.cancelled" for event in channel.messages) + + +@pytest.mark.asyncio +async def test_realtime_session_auto_vad_barge_in_cancels_and_queues_next_turn(): + backend = _BlockingBackend() + session, output_track, channel = _make_session(backend) + + speech = np.full(1600, 0.2, dtype=np.float32) + silence = np.zeros(1600, dtype=np.float32) + + await session.handle_client_event( + { + "type": "conversation.item.create", + "item": {"role": "user", "content": "first request"}, + } + ) + await session.handle_audio_chunk(speech, 16000, timestamp=1.0) + await session.handle_audio_chunk(speech, 16000, timestamp=1.1) + await session.handle_audio_chunk(silence, 16000, timestamp=1.2) + + await asyncio.wait_for(backend.started.wait(), timeout=1.0) + assert len(backend.turns) == 1 + + await session.handle_client_event( + { + "type": "conversation.item.create", + "item": {"role": "user", "content": "second request"}, + } + ) + await session.handle_audio_chunk(speech, 16000, timestamp=2.0) + await session.handle_audio_chunk(speech, 16000, timestamp=2.1) + await session.handle_audio_chunk(silence, 16000, timestamp=2.2) + + deadline = asyncio.get_running_loop().time() + 1.0 + while ( + session.active_task is not None and asyncio.get_running_loop().time() < deadline + ): + await asyncio.sleep(0.01) + + assert backend.cancelled == ["resp-cancel"] + assert len(backend.turns) == 2 + assert backend.turns[1].user_text == "second request" + assert output_track.clear_calls >= 2 + assert any( + event["type"] == "response.cancelled" and event["reason"] == "barge_in" + for event in channel.messages + ) + + +@pytest.mark.asyncio +async def test_realtime_session_supports_manual_push_to_talk_commit(): + backend = _ScriptedBackend() + session, output_track, channel = _make_session(backend) + + await session.handle_client_event({"type": "input_audio_buffer.start"}) + await session.handle_audio_chunk( + np.full(1600, 0.2, dtype=np.float32), 16000, timestamp=1.0 + ) + await session.handle_audio_chunk( + np.full(1600, 0.1, dtype=np.float32), 16000, timestamp=1.1 + ) + await session.handle_client_event({"type": "input_audio_buffer.commit"}) + + task = session.active_task + assert task is not None + await asyncio.wait_for(task, timeout=1.0) + + assert session.turn_mode == "manual" + assert session.manual_recording is False + assert len(backend.turns) == 1 + np.testing.assert_allclose( + backend.turns[0].user_audio, + np.concatenate( + [ + np.full(1600, 0.2, dtype=np.float32), + np.full(1600, 0.1, dtype=np.float32), + ] + ), + ) + assert any( + event["type"] == "input_audio_buffer.manual_started" + for event in channel.messages + ) + assert any( + event["type"] == "input_audio_buffer.manual_committed" + and event["empty"] is False + for event in channel.messages + ) + assert any(event["type"] == "turn.prepared" for event in channel.messages) + assert any(event["type"] == "response.done" for event in channel.messages) + assert len(output_track.enqueue_calls) == 1 + + +@pytest.mark.asyncio +async def test_realtime_session_supports_text_only_turns_with_history(): + backend = _ScriptedBackend() + session, output_track, channel = _make_session(backend) + + await _drive_text_turn(session, user_text="hello there") + await _drive_text_turn(session, user_text="follow up") + + assert len(backend.turns) == 2 + assert backend.turns[0].user_text == "hello there" + assert backend.turns[0].user_audio is None + assert backend.turns[1].history == [ + {"role": "user", "content": "hello there"}, + {"role": "assistant", "content": "answer-1"}, + ] + assert session.history == [ + {"role": "user", "content": "hello there"}, + {"role": "assistant", "content": "answer-1"}, + {"role": "user", "content": "follow up"}, + {"role": "assistant", "content": "answer-2"}, + ] + assert len(output_track.enqueue_calls) == 2 + + turn_events = [ + event for event in channel.messages if event["type"] == "turn.prepared" + ] + assert len(turn_events) == 2 + assert all(event["audio_sample_count"] == 0 for event in turn_events) + assert all(event["audio_sample_rate"] is None for event in turn_events) + + +@pytest.mark.asyncio +async def test_realtime_session_can_switch_between_vad_and_manual_modes(): + backend = _ScriptedBackend() + session, _output_track, channel = _make_session(backend) + + assert session.turn_mode == "vad" + + await session.handle_client_event( + { + "type": "session.update", + "session": {"audio": {"input_mode": "manual"}}, + } + ) + + assert session.turn_mode == "manual" + assert session.manual_recording is False + assert channel.messages[-1]["type"] == "session.updated" + assert channel.messages[-1]["session"]["audio"]["input_mode"] == "manual" + + await session.handle_audio_chunk( + np.full(1600, 0.2, dtype=np.float32), 16000, timestamp=1.0 + ) + assert session.active_task is None + + await session.handle_client_event( + { + "type": "session.update", + "session": {"audio": {"input_mode": "vad"}}, + } + ) + + assert session.turn_mode == "vad" + assert session.manual_recording is False + assert channel.messages[-1]["type"] == "session.updated" + assert channel.messages[-1]["session"]["audio"]["input_mode"] == "vad" diff --git a/tests/test_realtime_utils.py b/tests/test_realtime_utils.py new file mode 100644 index 000000000..992799e53 --- /dev/null +++ b/tests/test_realtime_utils.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from sglang_omni.realtime.utils import throttle + + +class _Recorder: + def __init__(self) -> None: + self.calls: list[tuple[str, float]] = [] + self._throttle_state: dict[str, float] = {} + + @throttle(0.5, timestamp_kw="timestamp") + async def record(self, label: str, *, timestamp: float) -> None: + self.calls.append((label, timestamp)) + + +@pytest.mark.asyncio +async def test_throttle_decorator_suppresses_calls_within_interval(): + recorder = _Recorder() + + await recorder.record("first", timestamp=1.0) + await recorder.record("suppressed", timestamp=1.2) + await recorder.record("second", timestamp=1.6) + + assert recorder.calls == [("first", 1.0), ("second", 1.6)] + + +@pytest.mark.asyncio +async def test_throttle_decorator_is_per_instance(): + left = _Recorder() + right = _Recorder() + + await left.record("left", timestamp=1.0) + await right.record("right", timestamp=1.1) + + assert left.calls == [("left", 1.0)] + assert right.calls == [("right", 1.1)] diff --git a/tests/test_realtime_vad.py b/tests/test_realtime_vad.py new file mode 100644 index 000000000..0bc458d5c --- /dev/null +++ b/tests/test_realtime_vad.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +from sglang_omni.realtime.vad import EnergyVad, VadConfig + + +def _dummy_chunk(num_samples: int = 3200) -> np.ndarray: + return np.zeros(num_samples, dtype=np.float32) + + +def test_webrtc_vad_detects_start_and_stop_with_scripted_votes(monkeypatch): + vad = EnergyVad( + VadConfig( + min_speech_s=0.06, + min_silence_s=0.08, + frame_duration_ms=20, + ) + ) + votes = iter( + [ + True, + True, + True, + False, + False, + False, + False, + ] + ) + monkeypatch.setattr(vad, "_detect_frame", lambda _frame: next(votes)) + + start_event = vad.process(_dummy_chunk(960)) + assert start_event.speech_started is True + assert vad.speaking is True + + stop_event = vad.process(_dummy_chunk(1280)) + assert stop_event.speech_stopped is True + assert vad.speaking is False + + +def test_webrtc_vad_ignores_short_spike(monkeypatch): + vad = EnergyVad( + VadConfig( + min_speech_s=0.10, + min_silence_s=0.10, + frame_duration_ms=20, + ) + ) + votes = iter([True, False, False, False, False]) + monkeypatch.setattr(vad, "_detect_frame", lambda _frame: next(votes)) + + event = vad.process(_dummy_chunk(1600)) + + assert event.speech_started is False + assert vad.speaking is False + + +def test_webrtc_vad_tracks_frame_statistics(monkeypatch): + vad = EnergyVad(VadConfig(frame_duration_ms=20)) + votes = iter([True, False, True, True]) + monkeypatch.setattr(vad, "_detect_frame", lambda _frame: next(votes)) + + vad.process(_dummy_chunk(1280)) + + assert vad.last_frame_count == 4 + assert vad.last_voiced_frame_count == 3 + assert vad.last_speech_ratio == 0.75 diff --git a/tests/test_realtime_ws_api.py b/tests/test_realtime_ws_api.py new file mode 100644 index 000000000..222f5565b --- /dev/null +++ b/tests/test_realtime_ws_api.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json + +import numpy as np +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from sglang_omni.realtime.backend import MockResponseBackend +from sglang_omni.serve.realtime_ws_api import create_realtime_ws_router + + +def _make_app( + *, + output_modalities: tuple[str, ...] = ("text", "audio"), + audio_mode: str = "playback", +) -> FastAPI: + app = FastAPI() + + def backend_factory(model_name: str, max_new_tokens: int) -> MockResponseBackend: + del max_new_tokens + return MockResponseBackend( + model=model_name, + output_modalities=output_modalities, + response_text="Mock websocket response.", + audio_mode=audio_mode, + inter_chunk_delay_s=0.0, + chunk_duration_s=0.05, + total_duration_s=0.1, + ) + + app.include_router( + create_realtime_ws_router( + model_name="mock-realtime-ws", + backend_factory=backend_factory, + ) + ) + return app + + +def test_realtime_ws_emits_session_created_event(): + client = TestClient(_make_app()) + + with client.websocket_connect("/v1/realtime/ws?model=demo-model") as websocket: + event = json.loads(websocket.receive_text()) + + assert event["type"] == "session.created" + assert event["model"] == "demo-model" + assert event["transport"] == {"type": "websocket"} + assert event["audio"]["input_encoding"] == "pcm16le" + assert event["audio"]["output_sample_rate"] == 24000 + + +def test_realtime_ws_supports_text_only_turns(): + client = TestClient(_make_app(output_modalities=("text",))) + + with client.websocket_connect("/v1/realtime/ws") as websocket: + created = json.loads(websocket.receive_text()) + assert created["type"] == "session.created" + + websocket.send_text( + json.dumps( + { + "type": "conversation.item.create", + "item": {"role": "user", "content": "hello there"}, + } + ) + ) + websocket.send_text(json.dumps({"type": "response.create"})) + + seen_types: list[str] = [] + done_event = None + for _ in range(16): + event = json.loads(websocket.receive_text()) + seen_types.append(event["type"]) + if event["type"] == "response.done": + done_event = event + break + + assert "conversation.item.created" in seen_types + assert "turn.prepared" in seen_types + assert "response.created" in seen_types + assert "response.output_text.delta" in seen_types + assert done_event is not None + assert done_event["text"] == "Mock websocket response." + + +def test_realtime_ws_accepts_binary_pcm_audio_in_manual_mode(): + client = TestClient(_make_app(output_modalities=("text",))) + + with client.websocket_connect("/v1/realtime/ws") as websocket: + created = json.loads(websocket.receive_text()) + assert created["type"] == "session.created" + + websocket.send_text( + json.dumps( + { + "type": "input_audio_format", + "sample_rate": 16000, + "encoding": "pcm16le", + } + ) + ) + updated = json.loads(websocket.receive_text()) + assert updated["type"] == "input_audio_format.updated" + assert updated["sample_rate"] == 16000 + + websocket.send_text( + json.dumps( + { + "type": "session.update", + "session": {"audio": {"input_mode": "manual"}}, + } + ) + ) + mode_event = json.loads(websocket.receive_text()) + assert mode_event["type"] == "session.updated" + assert mode_event["session"]["audio"]["input_mode"] == "manual" + + websocket.send_text(json.dumps({"type": "input_audio_buffer.start"})) + manual_started = json.loads(websocket.receive_text()) + assert manual_started["type"] == "input_audio_buffer.manual_started" + + pcm = (np.sin(np.linspace(0.0, np.pi * 8.0, 1600)) * 12000.0).astype("