11import json
22import logging
3+ import time
4+ import uuid
35from contextlib import asynccontextmanager
46from os import getenv
57
1315from langgraph_react_agent_base .agent import get_graph_closure
1416
1517
16- # Request/Response models
17- class ChatRequest (BaseModel ):
18- """Incoming chat request body for the /chat endpoint."""
18+ # OpenAI-compatible request/response models
19+ class ChatMessage (BaseModel ):
20+ role : str
21+ content : str
1922
20- message : str
2123
24+ class ChatCompletionRequest (BaseModel ):
25+ """OpenAI-compatible chat completion request."""
2226
23- class ChatResponse (BaseModel ):
24- """Structured chat response (answer and optional steps)."""
25-
26- answer : str
27- steps : list [str ]
27+ messages : list [ChatMessage ]
28+ model : str | None = None
29+ stream : bool = False
2830
2931
3032# Global variable for agent graph
@@ -33,27 +35,19 @@ class ChatResponse(BaseModel):
3335
3436@asynccontextmanager
3537async def lifespan (app : FastAPI ):
36- """Initialize the ReAct agent graph on startup and clear it on shutdown.
37-
38- Reads BASE_URL and MODEL_ID from the environment, builds the graph via
39- get_graph_closure, and sets the global agent_graph for the /chat endpoint.
40- """
38+ """Initialize the ReAct agent graph on startup and clear it on shutdown."""
4139 global agent_graph
4240
43- # Get environment variables
4441 base_url = getenv ("BASE_URL" )
4542 model_id = getenv ("MODEL_ID" )
4643
47- # Ensure base_url ends with /v1 if provided
4844 if base_url and not base_url .endswith ("/v1" ):
4945 base_url = base_url .rstrip ("/" ) + "/v1"
5046
51- # Get graph closure and create agent graph
5247 agent_graph = get_graph_closure (model_id = model_id , base_url = base_url )
5348
5449 yield
5550
56- # Cleanup on shutdown (if needed)
5751 agent_graph = None
5852
5953
@@ -65,49 +59,61 @@ async def lifespan(app: FastAPI):
6559)
6660
6761
68- @app .post ("/chat" )
69- async def chat (request : ChatRequest ):
70- """
71- Chat endpoint that accepts a message and returns the agent's response.
62+ def _build_langchain_messages (messages : list [ChatMessage ]) -> list [HumanMessage ]:
63+ """Extract the last user message from the OpenAI-format messages list."""
64+ for msg in reversed (messages ):
65+ if msg .role == "user" :
66+ return [HumanMessage (content = msg .content )]
67+ raise ValueError ("No user message found in messages list" )
68+
69+
70+ def _make_completion_id () -> str :
71+ return f"chatcmpl-{ uuid .uuid4 ().hex [:12 ]} "
7272
73- Args:
74- request: ChatRequest containing the user message
7573
76- Returns:
77- JSON response with full conversation history including tool calls
74+ @app .post ("/chat/completions" )
75+ async def chat_completions (request : ChatCompletionRequest ):
76+ """
77+ OpenAI-compatible chat completions endpoint.
78+
79+ When stream=false, returns a full chat.completion response.
80+ When stream=true, returns SSE chat.completion.chunk events.
7881 """
7982 global agent_graph
8083
8184 if agent_graph is None :
8285 raise HTTPException (status_code = 503 , detail = "Agent not initialized" )
8386
84- try :
85- messages = [ HumanMessage ( content = request .message )]
87+ langchain_messages = _build_langchain_messages ( request . messages )
88+ model_id = request .model or getenv ( "MODEL_ID" , "model" )
8689
87- # Use invoke to get the agent's response
90+ if request .stream :
91+ return await _handle_stream (langchain_messages , model_id )
92+ else :
93+ return await _handle_chat (langchain_messages , model_id )
94+
95+
96+ async def _handle_chat (messages : list [HumanMessage ], model_id : str ):
97+ """Handle non-streaming chat completion."""
98+ global agent_graph
99+
100+ try :
88101 result = await agent_graph .ainvoke (
89102 {"messages" : messages }, config = {"recursion_limit" : 10 }
90103 )
91104
92- response_messages = []
105+ # Extract the final assistant message content
106+ assistant_content = ""
107+ context_messages = []
93108
94109 if "messages" in result and len (result ["messages" ]) > 0 :
95110 for message in result ["messages" ]:
96- # 1. User message (HumanMessage)
97111 if isinstance (message , HumanMessage ):
98- response_messages .append (
99- {
100- "role" : "user" ,
101- "content" : message .content ,
102- }
112+ context_messages .append (
113+ {"role" : "user" , "content" : message .content }
103114 )
104-
105- # 2. AI message (AIMessage)
106115 elif isinstance (message , AIMessage ):
107- msg_data = {
108- "role" : "assistant" ,
109- "content" : message .content or "" ,
110- }
116+ msg_data = {"role" : "assistant" , "content" : message .content or "" }
111117 if message .tool_calls :
112118 msg_data ["tool_calls" ] = [
113119 {
@@ -120,11 +126,9 @@ async def chat(request: ChatRequest):
120126 }
121127 for tc in message .tool_calls
122128 ]
123- response_messages .append (msg_data )
124-
125- # 3. Tool response (ToolMessage)
129+ context_messages .append (msg_data )
126130 elif isinstance (message , ToolMessage ):
127- response_messages .append (
131+ context_messages .append (
128132 {
129133 "role" : "tool" ,
130134 "tool_call_id" : message .tool_call_id ,
@@ -133,38 +137,46 @@ async def chat(request: ChatRequest):
133137 }
134138 )
135139
136- return {"messages" : response_messages , "finish_reason" : "stop" }
140+ # Final assistant content is the last AIMessage with content
141+ for message in reversed (result ["messages" ]):
142+ if isinstance (message , AIMessage ) and message .content :
143+ assistant_content = message .content
144+ break
145+
146+ return {
147+ "id" : _make_completion_id (),
148+ "object" : "chat.completion" ,
149+ "created" : int (time .time ()),
150+ "model" : model_id ,
151+ "choices" : [
152+ {
153+ "index" : 0 ,
154+ "message" : {
155+ "role" : "assistant" ,
156+ "content" : assistant_content ,
157+ },
158+ "finish_reason" : "stop" ,
159+ }
160+ ],
161+ "context" : context_messages ,
162+ "usage" : None ,
163+ }
137164
138165 except Exception as e :
139166 raise HTTPException (
140167 status_code = 500 , detail = f"Error processing request: { str (e )} "
141168 )
142169
143170
144- @app .post ("/stream" )
145- async def stream (request : ChatRequest ):
146- """
147- Streaming chat endpoint that accepts a message and returns the agent's
148- response as Server-Sent Events (SSE).
149-
150- Event types:
151- - token: streamed text token from the LLM
152- - tool_call: tool invocation by the agent
153- - tool_result: result returned by a tool
154- - done: signals the stream is complete
155-
156- Args:
157- request: ChatRequest containing the user message
158- """
171+ async def _handle_stream (messages : list [HumanMessage ], model_id : str ):
172+ """Handle streaming chat completion with OpenAI-compatible SSE chunks."""
159173 global agent_graph
160174
161- if agent_graph is None :
162- raise HTTPException ( status_code = 503 , detail = "Agent not initialized" )
175+ completion_id = _make_completion_id ()
176+ created = int ( time . time () )
163177
164178 async def event_generator ():
165179 try :
166- messages = [HumanMessage (content = request .message )]
167-
168180 async for event in agent_graph .astream_events (
169181 {"messages" : messages },
170182 config = {"recursion_limit" : 10 },
@@ -176,28 +188,105 @@ async def event_generator():
176188 if kind == "on_chat_model_stream" :
177189 chunk = event ["data" ]["chunk" ]
178190 if chunk .content :
179- yield f"event: token\n data: { json .dumps ({'content' : chunk .content })} \n \n "
191+ data = {
192+ "id" : completion_id ,
193+ "object" : "chat.completion.chunk" ,
194+ "created" : created ,
195+ "model" : model_id ,
196+ "choices" : [
197+ {
198+ "index" : 0 ,
199+ "delta" : {"content" : chunk .content },
200+ "finish_reason" : None ,
201+ }
202+ ],
203+ }
204+ yield f"data: { json .dumps (data )} \n \n "
180205
181- # Complete tool call (after LLM finishes generating the call)
206+ # Tool calls (after LLM finishes generating the call)
182207 elif kind == "on_chat_model_end" :
183208 message = event ["data" ]["output" ]
184209 if hasattr (message , "tool_calls" ) and message .tool_calls :
185- for tc in message .tool_calls :
186- yield f"event: tool_call\n data: { json .dumps ({'name' : tc ['name' ], 'args' : tc ['args' ]})} \n \n "
210+ tool_calls_delta = [
211+ {
212+ "index" : i ,
213+ "id" : tc ["id" ],
214+ "type" : "function" ,
215+ "function" : {
216+ "name" : tc ["name" ],
217+ "arguments" : json .dumps (tc ["args" ]),
218+ },
219+ }
220+ for i , tc in enumerate (message .tool_calls )
221+ ]
222+ data = {
223+ "id" : completion_id ,
224+ "object" : "chat.completion.chunk" ,
225+ "created" : created ,
226+ "model" : model_id ,
227+ "choices" : [
228+ {
229+ "index" : 0 ,
230+ "delta" : {
231+ "role" : "assistant" ,
232+ "tool_calls" : tool_calls_delta ,
233+ },
234+ "finish_reason" : None ,
235+ }
236+ ],
237+ }
238+ yield f"data: { json .dumps (data )} \n \n "
187239
188240 # Tool execution results
189241 elif kind == "on_tool_end" :
190242 output = event ["data" ].get ("output" , "" )
191- # Extract content from ToolMessage if present
192243 if hasattr (output , "content" ):
193244 output = output .content
194- yield f"event: tool_result\n data: { json .dumps ({'name' : event .get ('name' , '' ), 'output' : str (output )})} \n \n "
195-
196- yield "event: done\n data: {}\n \n "
245+ data = {
246+ "id" : completion_id ,
247+ "object" : "chat.completion.chunk" ,
248+ "created" : created ,
249+ "model" : model_id ,
250+ "choices" : [
251+ {
252+ "index" : 0 ,
253+ "delta" : {
254+ "role" : "tool" ,
255+ "content" : str (output ),
256+ "name" : event .get ("name" , "" ),
257+ },
258+ "finish_reason" : None ,
259+ }
260+ ],
261+ }
262+ yield f"data: { json .dumps (data )} \n \n "
263+
264+ # Send final chunk with finish_reason
265+ final_data = {
266+ "id" : completion_id ,
267+ "object" : "chat.completion.chunk" ,
268+ "created" : created ,
269+ "model" : model_id ,
270+ "choices" : [
271+ {
272+ "index" : 0 ,
273+ "delta" : {},
274+ "finish_reason" : "stop" ,
275+ }
276+ ],
277+ }
278+ yield f"data: { json .dumps (final_data )} \n \n "
279+ yield "data: [DONE]\n \n "
197280
198- except Exception as e :
281+ except Exception :
199282 logger .exception ("Error in stream event_generator" )
200- yield f"event: error\n data: { json .dumps ({'detail' : 'Internal server error' })} \n \n "
283+ error_data = {
284+ "error" : {
285+ "message" : "Internal server error" ,
286+ "type" : "server_error" ,
287+ }
288+ }
289+ yield f"data: { json .dumps (error_data )} \n \n "
201290
202291 return StreamingResponse (
203292 event_generator (),
0 commit comments