-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathpipeline-patient.py
More file actions
341 lines (287 loc) · 12.5 KB
/
pipeline-patient.py
File metadata and controls
341 lines (287 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Voice Agent WebRTC Pipeline.
This module implements a voice agent pipeline using WebRTC for real-time
speech-to-speech communication with dynamic prompt support.
"""
import argparse
import asyncio
import json
import os
import sys
import uuid
from pathlib import Path
import httpx
import uvicorn
import yaml
from config import Config
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import InputAudioRawFrame, LLMMessagesFrame, TTSAudioRawFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
from pipecat.transports.network.webrtc_connection import (
IceServer,
SmallWebRTCConnection,
)
from websocket_transcript_output import WebsocketTranscriptOutput
from nvidia_pipecat.processors.audio_util import AudioRecorder
from nvidia_pipecat.processors.nvidia_context_aggregator import (
NvidiaTTSResponseCacher,
create_nvidia_context_aggregator,
)
from nvidia_pipecat.processors.transcript_synchronization import (
BotTranscriptSynchronization,
UserTranscriptSynchronization,
)
from nvidia_pipecat.services.nvidia_rag import NvidiaRAGService
from nvidia_pipecat.services.riva_speech import RivaASRService, RivaTTSService
load_dotenv(override=True)

# Load and validate the pipeline configuration from the YAML file pointed to
# by the CONFIG_PATH environment variable; fail fast with a clear error.
config_path = os.getenv("CONFIG_PATH")
if not config_path:
    raise ValueError("CONFIG_PATH environment variable is not set")
try:
    config = Config(**yaml.safe_load(Path(config_path).read_text()))
except FileNotFoundError as e:
    raise FileNotFoundError(f"Config file not found at: {config_path}") from e
except yaml.YAMLError as e:
    raise ValueError(f"Invalid YAML in config file: {e}") from e

# FastAPI app with fully permissive CORS (any origin/method/header).
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Active WebRTC peer connections and their LLM contexts, keyed by pc_id.
pcs_map: dict[str, SmallWebRTCConnection] = {}
contexts_map: dict[str, OpenAILLMContext] = {}

# Optional TURN server for ICE negotiation; an empty server list is used
# when TURN_SERVER_URL is not configured.
if os.getenv("TURN_SERVER_URL"):
    ice_servers = [
        IceServer(
            urls=os.getenv("TURN_SERVER_URL", ""),
            username=os.getenv("TURN_USERNAME", ""),
            credential=os.getenv("TURN_PASSWORD", ""),
        )
    ]
else:
    ice_servers = []
async def run_bot(webrtc_connection: SmallWebRTCConnection, ws: WebSocket) -> None:
    """Run the voice agent bot with WebRTC connection and WebSocket.

    Builds a pipecat pipeline (ASR -> RAG agent -> TTS) for a single stream
    and runs it until the pipeline finishes. Transcript updates are pushed
    to the UI over ``ws``.

    Args:
        webrtc_connection: The WebRTC connection for audio streaming
        ws: WebSocket connection for communication
    """
    # Unique id for this stream; used in recording filenames and task metadata.
    stream_id = uuid.uuid4()
    # 16 kHz audio in/out with Silero VAD; output is sent in 5 x 10ms chunks.
    transport_params = TransportParams(
        audio_in_enabled=True,
        audio_in_sample_rate=16000,
        audio_out_sample_rate=16000,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
        audio_out_10ms_chunks=5,
    )
    transport = SmallWebRTCTransport(
        webrtc_connection=webrtc_connection,
        params=transport_params,
    )
    # RAG-backed agent service. REQUEST_TIMEOUT (seconds) bounds each HTTP
    # request to the RAG server.
    # NOTE(review): the AsyncClient created here is never explicitly closed in
    # this function — presumably NvidiaRAGService owns its lifetime; confirm.
    agent = NvidiaRAGService(
        collection_name=config.NvidiaRAGService.collection_name,
        enable_citations=config.NvidiaRAGService.enable_citations,
        rag_server_url=config.NvidiaRAGService.rag_server_url,
        use_knowledge_base=config.NvidiaRAGService.use_knowledge_base,
        max_tokens=config.NvidiaRAGService.max_tokens,
        filler=config.Pipeline.filler,
        session = httpx.AsyncClient(timeout=float(os.getenv("REQUEST_TIMEOUT", 15.0)))
    )
    # Riva speech-to-text; NVIDIA_API_KEY is read from the environment.
    stt = RivaASRService(
        server=config.RivaASRService.server,
        api_key=os.getenv("NVIDIA_API_KEY"),
        language=config.RivaASRService.language,
        sample_rate=config.RivaASRService.sample_rate,
        automatic_punctuation=True,
        model=config.RivaASRService.model,
        function_id=config.RivaASRService.function_id,
    )
    # Load IPA dictionary with error handling
    # (pronunciation hints for TTS; ipa.json lives next to this script).
    ipa_file = Path(__file__).parent / "ipa.json"
    try:
        with open(ipa_file, encoding="utf-8") as f:
            ipa_dict = json.load(f)
    except FileNotFoundError as e:
        logger.error(f"IPA dictionary file not found at {ipa_file}")
        raise FileNotFoundError(f"IPA dictionary file not found at {ipa_file}") from e
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in IPA dictionary file: {e}")
        raise ValueError(f"Invalid JSON in IPA dictionary file: {e}") from e
    except Exception as e:
        logger.error(f"Error loading IPA dictionary: {e}")
        raise
    # Riva text-to-speech; optionally voice-cloned via ZERO_SHOT_AUDIO_PROMPT.
    tts = RivaTTSService(
        server=config.RivaTTSService.server,
        api_key=os.getenv("NVIDIA_API_KEY"),
        voice_id=config.RivaTTSService.voice_id,
        model=config.RivaTTSService.model,
        function_id=config.RivaTTSService.function_id,
        language=config.RivaTTSService.language,
        zero_shot_audio_prompt_file=(
            Path(os.getenv("ZERO_SHOT_AUDIO_PROMPT")) if os.getenv("ZERO_SHOT_AUDIO_PROMPT") else None
        ),
        ipa_dict=ipa_dict,
    )
    # by default, audio recording is disabled,
    # if you want to record audio, set the environment variable DUMP_AUDIO_FILES to true in ace_controller.env file
    enable_audio_recording = os.getenv("DUMP_AUDIO_FILES", "false").lower() == "true"
    if enable_audio_recording:
        # Create audio_dumps directory if it doesn't exist
        audio_dumps_dir = Path(__file__).parent / "audio_dumps"
        audio_dumps_dir.mkdir(exist_ok=True)
        # Taps that dump raw user (ASR) and bot (TTS) audio to WAV files.
        asr_recorder = AudioRecorder(
            output_file=str(audio_dumps_dir / f"asr_recording_{stream_id}.wav"),
            params=transport_params,
            frame_type=InputAudioRawFrame,
        )
        tts_recorder = AudioRecorder(
            output_file=str(audio_dumps_dir / f"tts_recording_{stream_id}.wav"),
            params=transport_params,
            frame_type=TTSAudioRawFrame,
        )
    else:
        asr_recorder = None
        tts_recorder = None
    # Used to synchronize the user and bot transcripts in the UI
    stt_transcript_synchronization = UserTranscriptSynchronization()
    tts_transcript_synchronization = BotTranscriptSynchronization()
    # Conversation starts empty; the UI can inject a system prompt later
    # via the websocket "context_reset" message.
    messages = [
    ]
    context = OpenAILLMContext(messages)
    # Store context globally so WebSocket can access it
    pc_id = webrtc_connection.pc_id
    contexts_map[pc_id] = context
    # Configure speculative speech processing based on environment variable
    # set this variable to true only if your agent backend does not retain every incoming request and the agent response in memory
    # we will keep this set to false since the healthcare agent retains memory in langgraph
    enable_speculative_speech = os.getenv("ENABLE_SPECULATIVE_SPEECH", "false").lower() == "true"
    if enable_speculative_speech:
        context_aggregator = create_nvidia_context_aggregator(context, send_interims=True)
        tts_response_cacher = NvidiaTTSResponseCacher()
    else:
        context_aggregator = agent.create_context_aggregator(context)
        tts_response_cacher = None
    # Forwards finalized transcripts to the UI over the websocket.
    transcript_processor_output = WebsocketTranscriptOutput(ws)
    # Processor order matters: frames flow top-to-bottom through this list.
    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            *([asr_recorder] if asr_recorder else []),  # Include asr_recorder only if enabled
            stt,  # Speech-To-Text
            stt_transcript_synchronization,
            context_aggregator.user(),
            agent,  # Agent Backend
            tts,  # Text-To-Speech
            *([tts_recorder] if tts_recorder else []),  # Include tts_recorder only if enabled
            *([tts_response_cacher] if tts_response_cacher else []),  # Include cacher only if enabled
            tts_transcript_synchronization,
            transcript_processor_output,
            transport.output(),  # Websocket output to client
            context_aggregator.assistant(),
        ]
    )
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            send_initial_empty_metrics=True,
            # NOTE(review): stream_id is a uuid.UUID object, not a str —
            # confirm downstream metadata consumers accept that.
            start_metadata={"stream_id": stream_id},
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Wait 50ms for custom prompt from UI before starting conversation
        await asyncio.sleep(0.05)
        # Kick off the conversation.
        # messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([LLMMessagesFrame(messages)])

    # Run the pipeline to completion; SIGINT is handled by uvicorn, not here.
    runner = PipelineRunner(handle_sigint=False)
    await runner.run(task)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for handling voice agent connections.

    Performs the WebRTC SDP handshake over the websocket, starts the bot
    pipeline as a background task, then loops receiving JSON control
    messages (currently "context_reset") from the UI.

    Args:
        websocket: The WebSocket connection to handle
    """
    await websocket.accept()
    try:
        request = await websocket.receive_json()
        pc_id = request.get("pc_id")
        if pc_id and pc_id in pcs_map:
            # Client reconnect: renegotiate the existing peer connection.
            pipecat_connection = pcs_map[pc_id]
            logger.info(f"Reusing existing connection for pc_id: {pc_id}")
            await pipecat_connection.renegotiate(sdp=request["sdp"], type=request["type"])
        else:
            pipecat_connection = SmallWebRTCConnection(ice_servers)
            await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])

            @pipecat_connection.event_handler("closed")
            async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
                logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}")
                pcs_map.pop(webrtc_connection.pc_id, None)  # Remove connection reference
                contexts_map.pop(webrtc_connection.pc_id, None)  # Remove context reference

        # Keep a reference to the task: asyncio stores only weak references
        # to tasks, so a discarded create_task() result may be garbage-
        # collected before the bot finishes. This local binding lives for
        # the duration of the receive loop below.
        bot_task = asyncio.create_task(run_bot(pipecat_connection, websocket))

        answer = pipecat_connection.get_answer()
        pcs_map[answer["pc_id"]] = pipecat_connection
        await websocket.send_json(answer)

        # Keep the connection open and handle control messages from the UI.
        while True:
            try:
                raw_message = await websocket.receive_text()
                # Parse JSON message from UI
                try:
                    data = json.loads(raw_message)
                    prompt = data.get("message", "").strip()
                    if data.get("type") == "context_reset" and prompt:
                        logger.info(f"Context reset from UI: {prompt}")
                        # Replace entire conversation context with new system prompt
                        pc_id = pipecat_connection.pc_id
                        if pc_id in contexts_map:
                            context = contexts_map[pc_id]
                            context.set_messages([{"role": "system", "content": prompt}])
                        else:
                            logger.warning(f"No context found for pc_id: {pc_id}")
                except json.JSONDecodeError:
                    logger.warning(f"Non-JSON message: {raw_message}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                break
    except WebSocketDisconnect:
        logger.info("Client disconnected from websocket")
@app.get("/get_prompt")
async def get_prompt():
    """Get the default system prompt.

    The actual prompt lives in the agent backend, so this endpoint only
    returns placeholder metadata for the UI.
    """
    prompt_info = {
        "prompt": "Not set in ace-controller, set in agent backend",
        "name": "System Prompt",
        "description": "Default system prompt for the System as set at the backend",
    }
    return prompt_info
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="WebRTC demo")
    # Help text fixed to match the real default (it previously claimed
    # "localhost" while the default was 0.0.0.0).
    parser.add_argument("--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)")
    parser.add_argument("--port", type=int, default=7860, help="Port for HTTP server (default: 7860)")
    parser.add_argument("--verbose", "-v", action="count")
    args = parser.parse_args()

    # Replace loguru's default stderr handler (id 0) with one at the
    # requested verbosity: TRACE when -v is given, DEBUG otherwise.
    logger.remove(0)
    level = "TRACE" if args.verbose else "DEBUG"
    logger.add(sys.stderr, level=level)

    uvicorn.run(app, host=args.host, port=args.port)