|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 | from contextlib import asynccontextmanager |
3 | 4 | from os import getenv |
4 | 5 |
|
5 | 6 | from fastapi import FastAPI, HTTPException |
| 7 | +from fastapi.responses import StreamingResponse |
6 | 8 | from llama_index_workflow_agent_base.agent import get_workflow_closure |
| 9 | +from llama_index_workflow_agent_base.workflow import ToolCallEvent, InputEvent |
7 | 10 | from pydantic import BaseModel |
8 | 11 |
|
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
9 | 14 |
|
10 | 15 | # Request/Response models |
11 | 16 | class ChatRequest(BaseModel): |
@@ -189,6 +194,66 @@ async def chat(request: ChatRequest): |
189 | 194 | ) |
190 | 195 |
|
191 | 196 |
|
@app.post("/stream")
async def stream(request: ChatRequest):
    """
    Streaming chat endpoint that accepts a message and returns the agent's
    response as Server-Sent Events (SSE).

    Event types:
    - tool_call: tool invocation by the agent (``name`` and ``args``)
    - tool_result: result returned by a tool (``name`` and ``output``)
    - token: final answer text (``content``)
    - done: signals the stream is complete
    - error: sent when the run raises; carries a generic ``detail`` and the
      stream then ends without a ``done`` event

    Args:
        request: ChatRequest containing the user message

    Returns:
        A StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 503 when the agent factory has not been initialized.
    """
    # Checked before streaming starts so the client gets a real 503 status;
    # once the response has begun, failures can only be reported in-band.
    if get_agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    def sse(event_name, payload):
        """Frame one SSE message with a JSON-encoded data line."""
        return f"event: {event_name}\ndata: {json.dumps(payload)}\n\n"

    async def event_generator():
        try:
            agent = get_agent()
            messages = [{"role": "user", "content": request.message}]

            handler = agent.run(input=messages)

            async for event in handler.stream_events():
                if isinstance(event, ToolCallEvent):
                    for tc in event.tool_calls:
                        yield sse(
                            "tool_call",
                            {"name": tc.tool_name, "args": tc.tool_kwargs},
                        )

                elif isinstance(event, InputEvent):
                    if not event.input:
                        continue
                    # Only the last message is inspected.
                    # NOTE(review): if the workflow appends several tool
                    # messages before emitting one InputEvent, earlier tool
                    # results are not reported -- confirm against the workflow.
                    last_msg = event.input[-1]
                    if getattr(last_msg, "role", None) == "tool":
                        additional = (
                            getattr(last_msg, "additional_kwargs", {}) or {}
                        )
                        yield sse(
                            "tool_result",
                            {
                                "name": additional.get("name", ""),
                                "output": _get_message_content(last_msg),
                            },
                        )

            # Awaiting the handler yields the workflow's final result.
            result = await handler
            if result and "response" in result:
                content = _get_message_content(result["response"].message)
                if content:
                    yield sse("token", {"content": content})

            yield sse("done", {})

        except Exception:
            # Headers are already sent, so report the failure in-band. The
            # traceback goes to the log; the client sees a generic message.
            logger.exception("Error in stream event_generator")
            yield sse("error", {"detail": "Internal server error"})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
| 255 | + |
| 256 | + |
192 | 257 | @app.get("/health") |
193 | 258 | async def health(): |
194 | 259 | """Return service health and whether the workflow closure has been initialized.""" |
|
0 commit comments