|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 | from contextlib import asynccontextmanager |
3 | 4 | from os import getenv |
4 | 5 |
|
5 | 6 | from fastapi import FastAPI, HTTPException |
| 7 | +from fastapi.responses import StreamingResponse |
6 | 8 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage |
7 | 9 | from pydantic import BaseModel |
8 | 10 |
|
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
9 | 13 | from langgraph_react_agent_base.agent import get_graph_closure |
10 | 14 |
|
11 | 15 |
|
@@ -137,6 +141,71 @@ async def chat(request: ChatRequest): |
137 | 141 | ) |
138 | 142 |
|
139 | 143 |
|
def _format_sse(event_type, payload):
    """Return one Server-Sent Events frame with a JSON-encoded data line.

    Args:
        event_type: Value written to the frame's ``event:`` field.
        payload: JSON-serializable object written to the ``data:`` field.

    Returns:
        The frame as text, terminated by the blank line that ends an SSE event.
    """
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"


@app.post("/stream")
async def stream(request: ChatRequest):
    """
    Streaming chat endpoint that accepts a message and returns the agent's
    response as Server-Sent Events (SSE).

    Event types:
    - token: streamed text token from the LLM
    - tool_call: tool invocation by the agent
    - tool_result: result returned by a tool
    - done: signals the stream completed without an exception
    - error: sent instead of ``done`` when an exception is raised while
      streaming; carries only a generic detail, the exception itself is logged

    Args:
        request: ChatRequest containing the user message

    Returns:
        StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 503 if the agent graph has not been initialized.
    """
    if agent_graph is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    async def event_generator():
        try:
            messages = [HumanMessage(content=request.message)]

            async for event in agent_graph.astream_events(
                {"messages": messages},
                config={"recursion_limit": 10},
                version="v2",
            ):
                kind = event["event"]

                # LLM streaming tokens; chunks with empty content are skipped.
                if kind == "on_chat_model_stream":
                    chunk = event["data"]["chunk"]
                    if chunk.content:
                        yield _format_sse("token", {"content": chunk.content})

                # Complete tool calls, reported once the model has finished
                # generating them rather than from partial stream chunks.
                elif kind == "on_chat_model_end":
                    message = event["data"]["output"]
                    for tool_call in getattr(message, "tool_calls", None) or []:
                        yield _format_sse(
                            "tool_call",
                            {"name": tool_call["name"], "args": tool_call["args"]},
                        )

                # Tool execution results.
                elif kind == "on_tool_end":
                    output = event["data"].get("output", "")
                    # Unwrap message-like outputs (anything exposing .content).
                    output = getattr(output, "content", output)
                    yield _format_sse(
                        "tool_result",
                        {"name": event.get("name", ""), "output": str(output)},
                    )

            yield _format_sse("done", {})

        except Exception:
            # The response has already started, so the failure can only be
            # reported in-band; keep the detail generic and log the traceback.
            logger.exception("Error in stream event_generator")
            yield _format_sse("error", {"detail": "Internal server error"})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        # NOTE(review): X-Accel-Buffering presumably targets an nginx-style
        # reverse proxy so events are not buffered - confirm deployment.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
| 207 | + |
| 208 | + |
140 | 209 | @app.get("/health") |
141 | 210 | async def health(): |
142 | 211 | """Return service health and whether the agent graph has been initialized.""" |
|
0 commit comments