|
3 | 3 | from os import getenv |
4 | 4 |
|
5 | 5 | from fastapi import FastAPI, HTTPException |
| 6 | +from fastapi.responses import StreamingResponse |
6 | 7 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage |
7 | 8 | from pydantic import BaseModel |
8 | 9 |
|
@@ -141,6 +142,69 @@ async def chat(request: ChatRequest): |
141 | 142 | ) |
142 | 143 |
|
143 | 144 |
|
def _format_sse_event(event_type: str, payload: object) -> str:
    """Serialize one Server-Sent Events frame.

    Args:
        event_type: Value written to the frame's ``event:`` field.
        payload: JSON-serializable object written to the ``data:`` field.

    Returns:
        The frame text, ending with the blank line that terminates an event.
    """
    # json.dumps (without indent) escapes newlines inside strings, so the
    # payload always fits on the single "data:" line emitted here.
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"


@app.post("/stream")
async def stream(request: ChatRequest):
    """
    Streaming chat endpoint that accepts a message and returns the agent's
    response as Server-Sent Events (SSE).

    Event types:
    - token: streamed text token from the LLM
    - tool_call: tool invocation by the agent
    - tool_result: result returned by a tool
    - done: signals the stream completed without an exception
    - error: an exception interrupted the stream; it is the final event and
      no done event follows it

    Args:
        request: ChatRequest containing the user message

    Returns:
        StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 503 when the agent graph has not been initialized.
    """
    # Imported locally so this endpoint does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    # Bind the graph once: the generator below runs after this function has
    # returned, and should use the same object that passed the None check.
    graph = agent_graph
    if graph is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    async def event_generator():
        try:
            messages = [HumanMessage(content=request.message)]

            async for event in graph.astream_events(
                {"messages": messages},
                config={"recursion_limit": 15},
                version="v2",
            ):
                kind = event["event"]

                # LLM streaming tokens
                if kind == "on_chat_model_stream":
                    chunk = event["data"]["chunk"]
                    if chunk.content:
                        yield _format_sse_event(
                            "token", {"content": chunk.content}
                        )

                # Complete tool calls (after the LLM finishes generating them)
                elif kind == "on_chat_model_end":
                    message = event["data"]["output"]
                    for call in getattr(message, "tool_calls", None) or ():
                        yield _format_sse_event(
                            "tool_call",
                            {"name": call["name"], "args": call["args"]},
                        )

                # Tool execution results
                elif kind == "on_tool_end":
                    output = event["data"].get("output", "")
                    # Unwrap message-like outputs to their content
                    if hasattr(output, "content"):
                        output = output.content
                    yield _format_sse_event(
                        "tool_result",
                        {"name": event.get("name", ""), "output": str(output)},
                    )

            yield _format_sse_event("done", {})

        except Exception as exc:
            # Part of the body may already have been yielded, so the failure
            # is reported in-band as an SSE event rather than re-raised.
            # Log it as well, otherwise the traceback would be lost.
            logging.getLogger(__name__).exception("Error while streaming /stream")
            yield _format_sse_event("error", {"detail": str(exc)})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
| 206 | + |
| 207 | + |
144 | 208 | @app.get("/health") |
145 | 209 | async def health(): |
146 | 210 | """Return service health and whether the agent graph has been initialized.""" |
|
0 commit comments