-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsocket_server.py
More file actions
214 lines (186 loc) · 8.3 KB
/
socket_server.py
File metadata and controls
214 lines (186 loc) · 8.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import json
import uuid
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query
from fastapi.middleware.cors import CORSMiddleware
from backend import chatbot, retrieve_all_threads, describe_image
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from starlette.websockets import WebSocketState
# ---------- Utilities ----------
def generate_thread_id():
    """Create a fresh conversation identifier as a UUID4 string."""
    return f"{uuid.uuid4()}"
# Worker threads used to run the synchronous LangGraph stream off the event loop.
executor = ThreadPoolExecutor(max_workers=2)  # For running synchronous stream
# ---------- FastAPI App ----------
app = FastAPI()
# Allow CORS for Next.js frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # NOTE(review): dev origin only — add production origins before deploy
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------- Stream Helper ----------
async def send_chunk(ws: WebSocket, chunk):
    """Forward one streamed LangGraph message to the client as JSON.

    ``ToolMessage`` chunks become ``tool_message`` frames; ``AIMessage``
    chunks become ``assistant_chunk`` frames. Other chunk types, and any
    socket that is no longer connected, are ignored.
    """
    if ws.application_state != WebSocketState.CONNECTED:
        return
    if isinstance(chunk, ToolMessage):
        payload = {
            "type": "tool_message",
            "tool": getattr(chunk, "name", "tool"),
            "content": chunk.content,
        }
    elif isinstance(chunk, AIMessage):
        payload = {"type": "assistant_chunk", "content": chunk.content}
    else:
        # Non-displayable chunk (e.g. system/metadata) — nothing to send.
        return
    await safe_send(ws, json.dumps(payload))
async def safe_send(ws, msg):
    """Send *msg* over *ws*, suppressing send failures on a closed socket.

    Delivery is best-effort: the connection state is checked first, and any
    error raised by the send itself (the peer can disconnect between the
    check and the send) is swallowed.
    """
    if ws.application_state == WebSocketState.CONNECTED:
        try:
            await ws.send_text(msg)
        except Exception:
            # Narrowed from a bare `except:` — a bare clause also swallows
            # BaseExceptions such as asyncio.CancelledError and
            # KeyboardInterrupt, which must propagate for task cancellation
            # and shutdown to work correctly.
            pass  # ignore if already closed
async def stream_to_ws(websocket: WebSocket, user_input: str, thread_id: str):
    """Run the synchronous LangGraph stream in a worker thread and relay
    each chunk to the WebSocket, then send the final assistant message.

    Args:
        websocket: Connected client socket to push frames to.
        user_input: The user's message text for this turn.
        thread_id: LangGraph thread id used for checkpointed memory.
    """
    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10 and
    # can return the wrong loop outside the main thread.
    loop = asyncio.get_running_loop()

    def run_stream():
        # Executes inside the ThreadPoolExecutor; each chunk is marshalled
        # back onto the event loop thread via run_coroutine_threadsafe.
        for msg_chunk, metadata in chatbot.stream(
            {"messages": [HumanMessage(content=user_input)]},
            config={
                "configurable": {"thread_id": thread_id},
                "metadata": {"thread_id": thread_id},
                "run_name": "chat_turn",
            },
            stream_mode="messages",
        ):
            asyncio.run_coroutine_threadsafe(send_chunk(websocket, msg_chunk), loop)

    # Block this coroutine (not the event loop) until the stream completes.
    await loop.run_in_executor(executor, run_stream)

    # After streaming, read the checkpointed state and emit the final answer
    # (the most recent AIMessage, or "" if none exists).
    state = chatbot.get_state(config={"configurable": {"thread_id": thread_id}})
    messages = state.values.get("messages", [])
    final_msg = next(
        (m.content for m in reversed(messages) if isinstance(m, AIMessage)), ""
    )
    await safe_send(
        websocket, json.dumps({"type": "assistant_final", "content": final_msg})
    )
def serialize_messages(messages):
    """Convert LangChain message objects into JSON-friendly dicts for the UI.

    Human, AI, and tool messages map to roles ``user``, ``assistant`` and
    ``tool`` respectively; any other message type is skipped.
    """
    out = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            entry = {"role": "user", "content": msg.content}
        elif isinstance(msg, AIMessage):
            entry = {"role": "assistant", "content": msg.content}
        elif isinstance(msg, ToolMessage):
            entry = {
                "role": "tool",
                "content": msg.content,
                "tool": getattr(msg, "name", "tool"),
            }
        else:
            continue
        out.append(entry)
    return out
# ---------- WebSocket Endpoint ----------
@app.websocket("/ws/chat")
async def chat_endpoint(websocket: WebSocket, thread_id: Optional[str] = Query(None)):
    """Bidirectional chat socket.

    If the ``thread_id`` query parameter is absent, a new UUID thread is
    created and announced with a ``session_create`` frame; otherwise the
    existing thread's history is replayed as a ``thread_messages`` frame.
    Each incoming JSON frame is then dispatched on its ``type`` field:
    ``user_message``, ``get_threads``, ``set_thread``, ``fetch_thread``,
    ``uploading_file``. Unknown types are silently ignored.
    """
    await websocket.accept()
    if not thread_id:
        # Fresh session: mint an id and tell the client about it.
        thread_id = generate_thread_id()
        await safe_send(websocket,json.dumps({"type": "session_create", "thread_id": thread_id}))
    else:
        # Send existing thread messages when connecting to existing thread
        try:
            state = chatbot.get_state(config={"configurable": {"thread_id": thread_id}})
            messages = state.values.get("messages", [])
            await safe_send(websocket,json.dumps({
                "type": "thread_messages",
                "thread_id": thread_id,
                "messages": serialize_messages(messages)
            }))
        except Exception as e:
            print(f"Error loading thread {thread_id}: {e}")
            # If thread doesn't exist, create new one
            thread_id = generate_thread_id()
            await safe_send(websocket,json.dumps({"type": "session_create", "thread_id": thread_id}))
    try:
        # Main receive loop: one JSON frame per client message; loop exits
        # via WebSocketDisconnect.
        while True:
            msg = await websocket.receive_text()
            data = json.loads(msg)
            msg_type = data.get("type")
            if msg_type == "user_message":
                content = data.get("content", "")
                # Ack first so the UI can render the user's message immediately.
                await safe_send(websocket,json.dumps({"type": "user_ack", "content": content}))
                await stream_to_ws(websocket, content, thread_id)
            elif msg_type == "get_threads":
                threads = retrieve_all_threads()
                await safe_send(websocket,json.dumps({"type": "threads_list", "threads": threads}))
            elif msg_type == "set_thread":
                # Switch this connection's active thread (or start a new one).
                new_thread_id = data.get("thread_id")
                if not new_thread_id:
                    # Create new thread
                    new_thread_id = generate_thread_id()
                    thread_id = new_thread_id
                    await safe_send(websocket,json.dumps({
                        "type": "session_create",
                        "thread_id": thread_id
                    }))
                    # Clear messages for new thread
                    await safe_send(websocket,json.dumps({
                        "type": "thread_messages",
                        "thread_id": thread_id,
                        "messages": []
                    }))
                else:
                    # Switch to existing thread
                    thread_id = new_thread_id
                    try:
                        state = chatbot.get_state(config={"configurable": {"thread_id": thread_id}})
                        messages = state.values.get("messages", [])
                        await safe_send(websocket,json.dumps({
                            "type": "thread_set",
                            "thread_id": thread_id
                        }))
                        await safe_send(websocket,json.dumps({
                            "type": "thread_messages",
                            "thread_id": thread_id,
                            "messages": serialize_messages(messages)
                        }))
                    except Exception as e:
                        await safe_send(websocket,json.dumps({
                            "type": "error",
                            "message": f"Thread not found: {e}"
                        }))
            elif msg_type == "fetch_thread":
                # Read-only history lookup: does NOT change the active thread.
                requested_thread = data.get("thread_id")
                try:
                    state = chatbot.get_state(config={"configurable": {"thread_id": requested_thread}})
                    messages = state.values.get("messages", [])
                    await safe_send(websocket,json.dumps({
                        "type": "thread_messages",
                        "thread_id": requested_thread,
                        "messages": serialize_messages(messages)
                    }))
                except Exception as e:
                    await safe_send(websocket,json.dumps({
                        "type": "error",
                        "message": f"Error fetching thread: {e}"
                    }))
            elif msg_type == "uploading_file":
                file_base64 = data.get("file")  # Base64 from frontend
                # Acknowledge
                await safe_send(websocket,json.dumps({
                    "type": "upload_ack",
                    "message": "Image received, analyzing..."
                }))
                # Call vision tool directly
                result = describe_image.invoke({"image_base64": file_base64})
                # Feed analysis into conversation (memory saved)
                await stream_to_ws(
                    websocket,
                    f"User uploaded an image. Here is the analysis: {result}",
                    thread_id,
                )
    except WebSocketDisconnect:
        print("Client disconnected.")
    except Exception as e:
        # Last-resort handler: report the failure to the client if still connected.
        await safe_send(websocket,json.dumps({"type": "error", "message": str(e)}))
# ---------- Run ----------
if __name__ == "__main__":
    import uvicorn
    # Dev server: reload=True restarts on code change; binds all interfaces on port 8001.
    uvicorn.run("socket_server:app", host="0.0.0.0", port=8001, reload=True)