77from loguru import logger
88
99from src .config .config import Config
10-
11-
12-
13- class WebSocketClient :
14- def __init__ (self , uri : str = "ws://rag-server:8000/ws" ):
15- self .uri = uri
16- self .websocket : Optional [websockets .WebSocketClientProtocol ] = None
17- self ._connection_lock = asyncio .Lock ()
18-
19- async def connect (self ):
20- if not self .websocket :
21- self .websocket = await websockets .connect (self .uri )
22- logger .info ("Connected to WebSocket server" )
23- return self .websocket
24-
25- async def disconnect (self ):
26- if self .websocket :
27- await self .websocket .close ()
28- self .websocket = None
29- logger .info ("Disconnected from WebSocket server" )
30-
31- async def send_search_query (self , query : str , payload :str ) -> str :
32- try :
33- # Ensure connection is established
34- if not self .websocket :
35- await self .connect ()
36-
37- # Send search query
38- await self .websocket .send (json .dumps ({"query" : query }))
39-
40- # Wait for response
41- response = await self .websocket .recv ()
42- data = json .loads (response )
43-
44- if data ["status" ] == "success" :
45- # Format results for display
46- results = data ["results" ]
47- formatted_results = "\n \n " .join ([
48- f"Title: { result ['title' ]} \n URL: { result ['url' ]} "
49- for result in results
50- ])
51- return formatted_results , payload
52- else :
53- return f"Error: { data ['message' ], payload } "
54-
55- except websockets .exceptions .ConnectionClosedError :
56- logger .error ("Connection closed unexpectedly. Attempting to reconnect..." )
57- self .websocket = None
58- return "Connection lost. Please try again."
59- except Exception as e :
60- logger .error (f"Error during search: { str (e )} " )
61- return f"Error: { str (e )} "
62-
63- async def ensure_connection (self ):
64- if self .websocket is None or self .websocket .closed :
65- await self .connect ()
66-
67-
68- async def connect (self ):
69- """Establish WebSocket connection and start heartbeat monitoring"""
70- async with self ._connection_lock :
71- try :
72- uri = Config .WEBSOCKET_URI
73- if not uri .startswith (('ws://' , 'wss://' )):
74- logger .error ("Invalid WebSocket URI format" )
75- return
76-
77- self .websocket = await websockets .connect (
78- uri ,
79- ping_interval = 20 ,
80- ping_timeout = 60 ,
81- max_size = 10_485_760
82- )
83- logger .info ("Connected to server....." )
84-
85- # self._heartbeat_task = asyncio.create_task(self._heartbeat())
86-
87- return True
88-
89- except Exception as e :
90- logger .error (f"Connection error: { e } " )
91- return False
92-
93- async def handle_request (
94- self , action : str , payload : dict = {}
95- ) -> Tuple [str , List [Tuple [str , str ]]]:
96- """
97- Handle WebSocket requests to the server.
98-
99- Args:
100- action (str): The action to perform (e.g., 'search', 'ingest_data').
101- payload (dict): The payload containing request data.
102-
103- Returns:
104- Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
105- and updated chat history.
106- """
107-
108- logger .info ("Into handle search function.." )
109-
110- if action == "search" :
111- query = payload ["query" ]
112- if not query .strip ():
113- logger .error (f"No input provided" )
114- return "" , [(payload .get ("query" , "" ), "No query Entered" )]
115-
116- try :
117-
118- logger .info ("Ensuring Connection...." )
119- await self .ensure_connection ()
120-
121- result = await self ._handle_websocket_communication (action , payload )
122- return result
123-
124- except Exception as e :
125- logger .error (f"Connection error: { e } " )
126- await self .disconnect ()
127- return "" , [(payload .get ("query" , "" ), f"Connection error: { str (e )} " )]
128-
129-
130- async def _handle_websocket_communication (
131- self , action : str , payload : dict
132- ) -> Tuple [str , List [Tuple [str , str ]]]:
133- """
134- Handle the WebSocket communication with the server.
135-
136- Args:
137- action (str): The action to perform.
138- payload (dict): The payload containing request data.
139-
140- Returns:
141- Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
142- and updated chat history.
143- """
144- try :
145- await self .websocket .send (json .dumps ({"action" : action , "payload" : payload }))
146-
147- while True :
148- response = await self .websocket .recv ()
149- response_data = json .loads (response )
150-
151- logger .info ("Response received..." )
152-
153- # Handle heartbeat
154- if response_data .get ("type" ) == "ping" :
155- await self .websocket .send (json .dumps ({
156- "action" : "pong" ,
157- "timestamp" : response_data .get ("timestamp" )
158- }))
159- continue
160-
161- result = response_data .get ("result" , "No response from server" )
162- if result :
163- if action == "search" :
164- history = payload .get ("history" , [])
165- new_message = (payload .get ("query" , "" ), result )
166- updated_history = history + [new_message ]
167- return "" , updated_history
168- elif action == "ingest_data" :
169- return result , []
170-
171- error = response_data .get ("error" )
172- if error :
173- return "" , [(payload .get ("query" , "" ), f"Error: { error } " )]
174-
175- except Exception as e :
176- logger .error (f"Communication error: { e } " )
177- return "" , [(payload .get ("query" , "" ), f"Communication error: { str (e )} " )]
178-
10+ from src .websocket .web_socket_client import WebSocketClient
17911
18012
18113ws_client = WebSocketClient (Config .WEBSOCKET_URI )
@@ -212,46 +44,56 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
21244 return None
21345
21446
215- # Create Gradio interface
21647with gr .Blocks (
217- title = "Capec Chatbot" ,
218- theme = gr .themes .Soft (),
219- css = """
220- .gradio-container {
221- max-width: 700px;
222- margin: auto;
223- font-family: Arial, sans-serif;
224- }
225- #header {
226- text-align: center;
227- font-size: 1.5rem;
228- font-weight: bold;
229- color: #008080;
230- padding: 0.125rem;
231- }
232- #input-container {
233- display: flex;
234- align-items: center;
235- background-color: #f7f7f8;
236- padding: 0.25rem;
237- border-radius: 8px;
238- margin-top: 0.25rem;
239- }
240- #chatbot {
241- border: 1px solid #E5E7EB;
242- border-radius: 8px;
243- background-color: #FFFFFF;
244- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
245- }
246- .gr-button-primary {
247- background-color: #008080;
248- border-color: #008080;
249- }
250- .gr-button-primary:hover {
251- background-color: #006666;
252- }
253- """
254- ) as demo :
48+ title = "CAPEC RAG Chatbot" ,
49+ theme = gr .themes .Soft (),
50+ css = """
51+ .gradio-container {
52+ max-width: 700px;
53+ margin: auto;
54+ font-family: Arial, sans-serif;
55+ display: flex;
56+ flex-direction: column;
57+ height: 100vh;
58+ }
59+ #header {
60+ text-align: center;
61+ font-size: 1.5rem;
62+ font-weight: bold;
63+ color: #008080;
64+ padding: 0.125rem;
65+ flex: 0 0 auto;
66+ }
67+ #input-container {
68+ display: flex;
69+ align-items: center;
70+ background-color: #f7f7f8;
71+ padding: 0.25rem;
72+ border-radius: 8px;
73+ margin-top: 0.25rem;
74+ flex: 0 0 auto;
75+ }
76+ #chatbot {
77+ border: 1px solid #E5E7EB;
78+ border-radius: 8px;
79+ background-color: #FFFFFF;
80+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
81+ flex: 1 1 auto;
82+ min-height: 0;
83+ display: flex;
84+ flex-direction: column;
85+ overflow-y: auto; /* To allow scrolling if content overflows */
86+ min-height: 72vh;
87+ }
88+ .gr-button-primary {
89+ background-color: #008080;
90+ border-color: #008080;
91+ }
92+ .gr-button-primary:hover {
93+ background-color: #006666;
94+ }
95+ """
96+ ) as demo :
25597
25698 # Header
25799 gr .Markdown (
@@ -260,7 +102,6 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
260102
261103 # Chatbot Component
262104 chatbot = gr .Chatbot (
263- height = 450 ,
264105 show_label = False ,
265106 container = True ,
266107 elem_id = "chatbot"
@@ -277,6 +118,7 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
277118 )
278119 send_button = gr .Button ("Send" , variant = "primary" , scale = 1 )
279120 clear_button = gr .Button ("Clear Chat" , variant = "secondary" )
121+
280122
281123 # Button Functionality
282124 send_button .click (
@@ -302,3 +144,4 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
302144 share = False ,
303145 debug = True ,
304146 show_error = True ,)
147+
0 commit comments