Skip to content

Commit bb61eb9

Browse files
authored
Update: Updated client code for better UI
1 parent 27a8ae8 commit bb61eb9

File tree

1 file changed

+169
-1
lines changed

1 file changed

+169
-1
lines changed

client.py

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,175 @@
77
from loguru import logger
88

99
from src.config.config import Config
10-
from src.websocket.websocket_client import WebSocketClient
10+
11+
12+
13+
class WebSocketClient:
14+
def __init__(self, uri: str = "ws://rag-server:8000/ws"):
15+
self.uri = uri
16+
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
17+
self._connection_lock = asyncio.Lock()
18+
19+
async def connect(self):
20+
if not self.websocket:
21+
self.websocket = await websockets.connect(self.uri)
22+
logger.info("Connected to WebSocket server")
23+
return self.websocket
24+
25+
async def disconnect(self):
26+
if self.websocket:
27+
await self.websocket.close()
28+
self.websocket = None
29+
logger.info("Disconnected from WebSocket server")
30+
31+
async def send_search_query(self, query: str, payload:str) -> str:
32+
try:
33+
# Ensure connection is established
34+
if not self.websocket:
35+
await self.connect()
36+
37+
# Send search query
38+
await self.websocket.send(json.dumps({"query": query}))
39+
40+
# Wait for response
41+
response = await self.websocket.recv()
42+
data = json.loads(response)
43+
44+
if data["status"] == "success":
45+
# Format results for display
46+
results = data["results"]
47+
formatted_results = "\n\n".join([
48+
f"Title: {result['title']}\nURL: {result['url']}"
49+
for result in results
50+
])
51+
return formatted_results, payload
52+
else:
53+
return f"Error: {data['message'], payload}"
54+
55+
except websockets.exceptions.ConnectionClosedError:
56+
logger.error("Connection closed unexpectedly. Attempting to reconnect...")
57+
self.websocket = None
58+
return "Connection lost. Please try again."
59+
except Exception as e:
60+
logger.error(f"Error during search: {str(e)}")
61+
return f"Error: {str(e)}"
62+
63+
async def ensure_connection(self):
64+
if self.websocket is None or self.websocket.closed:
65+
await self.connect()
66+
67+
68+
async def connect(self):
69+
"""Establish WebSocket connection and start heartbeat monitoring"""
70+
async with self._connection_lock:
71+
try:
72+
uri = Config.WEBSOCKET_URI
73+
if not uri.startswith(('ws://', 'wss://')):
74+
logger.error("Invalid WebSocket URI format")
75+
return
76+
77+
self.websocket = await websockets.connect(
78+
uri,
79+
ping_interval=20,
80+
ping_timeout=60,
81+
max_size=10_485_760
82+
)
83+
logger.info("Connected to server.....")
84+
85+
# self._heartbeat_task = asyncio.create_task(self._heartbeat())
86+
87+
return True
88+
89+
except Exception as e:
90+
logger.error(f"Connection error: {e}")
91+
return False
92+
93+
async def handle_request(
94+
self, action: str, payload: dict = {}
95+
) -> Tuple[str, List[Tuple[str, str]]]:
96+
"""
97+
Handle WebSocket requests to the server.
98+
99+
Args:
100+
action (str): The action to perform (e.g., 'search', 'ingest_data').
101+
payload (dict): The payload containing request data.
102+
103+
Returns:
104+
Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
105+
and updated chat history.
106+
"""
107+
108+
logger.info("Into handle search function..")
109+
110+
if action == "search":
111+
query = payload["query"]
112+
if not query.strip():
113+
logger.error(f"No input provided")
114+
return "", [(payload.get("query", ""), "No query Entered")]
115+
116+
try:
117+
118+
logger.info("Ensuring Connection....")
119+
await self.ensure_connection()
120+
121+
result = await self._handle_websocket_communication(action, payload)
122+
return result
123+
124+
except Exception as e:
125+
logger.error(f"Connection error: {e}")
126+
await self.disconnect()
127+
return "", [(payload.get("query", ""), f"Connection error: {str(e)}")]
128+
129+
130+
async def _handle_websocket_communication(
131+
self, action: str, payload: dict
132+
) -> Tuple[str, List[Tuple[str, str]]]:
133+
"""
134+
Handle the WebSocket communication with the server.
135+
136+
Args:
137+
action (str): The action to perform.
138+
payload (dict): The payload containing request data.
139+
140+
Returns:
141+
Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
142+
and updated chat history.
143+
"""
144+
try:
145+
await self.websocket.send(json.dumps({"action": action, "payload": payload}))
146+
147+
while True:
148+
response = await self.websocket.recv()
149+
response_data = json.loads(response)
150+
151+
logger.info("Response received...")
152+
153+
# Handle heartbeat
154+
if response_data.get("type") == "ping":
155+
await self.websocket.send(json.dumps({
156+
"action": "pong",
157+
"timestamp": response_data.get("timestamp")
158+
}))
159+
continue
160+
161+
result = response_data.get("result", "No response from server")
162+
if result:
163+
if action == "search":
164+
history = payload.get("history", [])
165+
new_message = (payload.get("query", ""), result)
166+
updated_history = history + [new_message]
167+
return "", updated_history
168+
elif action == "ingest_data":
169+
return result, []
170+
171+
error = response_data.get("error")
172+
if error:
173+
return "", [(payload.get("query", ""), f"Error: {error}")]
174+
175+
except Exception as e:
176+
logger.error(f"Communication error: {e}")
177+
return "", [(payload.get("query", ""), f"Communication error: {str(e)}")]
178+
11179

12180

13181
ws_client = WebSocketClient(Config.WEBSOCKET_URI)

0 commit comments

Comments
 (0)