Skip to content

Commit bd20a79

Browse files
authored
Merge pull request #4 from cyber-evangelists/dev-branch
Dev branch
2 parents d1b56d9 + 1204d4a commit bd20a79

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+426
-63593
lines changed

Dockerfile

Lines changed: 0 additions & 26 deletions
This file was deleted.

capec-dataset/1000.csv

Lines changed: 0 additions & 560 deletions
This file was deleted.

capec-dataset/3000.csv

Lines changed: 0 additions & 560 deletions
This file was deleted.

capec-dataset/333.csv

Lines changed: 0 additions & 37 deletions
This file was deleted.

capec-dataset/658.csv

Lines changed: 0 additions & 178 deletions
This file was deleted.

capec-dataset/659.csv

Lines changed: 0 additions & 39 deletions
This file was deleted.

capec-dataset/test.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

client.py

Lines changed: 52 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -7,175 +7,7 @@
77
from loguru import logger
88

99
from src.config.config import Config
10-
11-
12-
13-
class WebSocketClient:
14-
def __init__(self, uri: str = "ws://rag-server:8000/ws"):
15-
self.uri = uri
16-
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
17-
self._connection_lock = asyncio.Lock()
18-
19-
async def connect(self):
20-
if not self.websocket:
21-
self.websocket = await websockets.connect(self.uri)
22-
logger.info("Connected to WebSocket server")
23-
return self.websocket
24-
25-
async def disconnect(self):
26-
if self.websocket:
27-
await self.websocket.close()
28-
self.websocket = None
29-
logger.info("Disconnected from WebSocket server")
30-
31-
async def send_search_query(self, query: str, payload:str) -> str:
32-
try:
33-
# Ensure connection is established
34-
if not self.websocket:
35-
await self.connect()
36-
37-
# Send search query
38-
await self.websocket.send(json.dumps({"query": query}))
39-
40-
# Wait for response
41-
response = await self.websocket.recv()
42-
data = json.loads(response)
43-
44-
if data["status"] == "success":
45-
# Format results for display
46-
results = data["results"]
47-
formatted_results = "\n\n".join([
48-
f"Title: {result['title']}\nURL: {result['url']}"
49-
for result in results
50-
])
51-
return formatted_results, payload
52-
else:
53-
return f"Error: {data['message'], payload}"
54-
55-
except websockets.exceptions.ConnectionClosedError:
56-
logger.error("Connection closed unexpectedly. Attempting to reconnect...")
57-
self.websocket = None
58-
return "Connection lost. Please try again."
59-
except Exception as e:
60-
logger.error(f"Error during search: {str(e)}")
61-
return f"Error: {str(e)}"
62-
63-
async def ensure_connection(self):
64-
if self.websocket is None or self.websocket.closed:
65-
await self.connect()
66-
67-
68-
async def connect(self):
69-
"""Establish WebSocket connection and start heartbeat monitoring"""
70-
async with self._connection_lock:
71-
try:
72-
uri = Config.WEBSOCKET_URI
73-
if not uri.startswith(('ws://', 'wss://')):
74-
logger.error("Invalid WebSocket URI format")
75-
return
76-
77-
self.websocket = await websockets.connect(
78-
uri,
79-
ping_interval=20,
80-
ping_timeout=60,
81-
max_size=10_485_760
82-
)
83-
logger.info("Connected to server.....")
84-
85-
# self._heartbeat_task = asyncio.create_task(self._heartbeat())
86-
87-
return True
88-
89-
except Exception as e:
90-
logger.error(f"Connection error: {e}")
91-
return False
92-
93-
async def handle_request(
94-
self, action: str, payload: dict = {}
95-
) -> Tuple[str, List[Tuple[str, str]]]:
96-
"""
97-
Handle WebSocket requests to the server.
98-
99-
Args:
100-
action (str): The action to perform (e.g., 'search', 'ingest_data').
101-
payload (dict): The payload containing request data.
102-
103-
Returns:
104-
Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
105-
and updated chat history.
106-
"""
107-
108-
logger.info("Into handle search function..")
109-
110-
if action == "search":
111-
query = payload["query"]
112-
if not query.strip():
113-
logger.error(f"No input provided")
114-
return "", [(payload.get("query", ""), "No query Entered")]
115-
116-
try:
117-
118-
logger.info("Ensuring Connection....")
119-
await self.ensure_connection()
120-
121-
result = await self._handle_websocket_communication(action, payload)
122-
return result
123-
124-
except Exception as e:
125-
logger.error(f"Connection error: {e}")
126-
await self.disconnect()
127-
return "", [(payload.get("query", ""), f"Connection error: {str(e)}")]
128-
129-
130-
async def _handle_websocket_communication(
131-
self, action: str, payload: dict
132-
) -> Tuple[str, List[Tuple[str, str]]]:
133-
"""
134-
Handle the WebSocket communication with the server.
135-
136-
Args:
137-
action (str): The action to perform.
138-
payload (dict): The payload containing request data.
139-
140-
Returns:
141-
Tuple[str, List[Tuple[str, str]]]: A tuple containing the response message
142-
and updated chat history.
143-
"""
144-
try:
145-
await self.websocket.send(json.dumps({"action": action, "payload": payload}))
146-
147-
while True:
148-
response = await self.websocket.recv()
149-
response_data = json.loads(response)
150-
151-
logger.info("Response received...")
152-
153-
# Handle heartbeat
154-
if response_data.get("type") == "ping":
155-
await self.websocket.send(json.dumps({
156-
"action": "pong",
157-
"timestamp": response_data.get("timestamp")
158-
}))
159-
continue
160-
161-
result = response_data.get("result", "No response from server")
162-
if result:
163-
if action == "search":
164-
history = payload.get("history", [])
165-
new_message = (payload.get("query", ""), result)
166-
updated_history = history + [new_message]
167-
return "", updated_history
168-
elif action == "ingest_data":
169-
return result, []
170-
171-
error = response_data.get("error")
172-
if error:
173-
return "", [(payload.get("query", ""), f"Error: {error}")]
174-
175-
except Exception as e:
176-
logger.error(f"Communication error: {e}")
177-
return "", [(payload.get("query", ""), f"Communication error: {str(e)}")]
178-
10+
from src.websocket.web_socket_client import WebSocketClient
17911

18012

18113
ws_client = WebSocketClient(Config.WEBSOCKET_URI)
@@ -212,46 +44,56 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
21244
return None
21345

21446

215-
# Create Gradio interface
21647
with gr.Blocks(
217-
title="Capec Chatbot",
218-
theme=gr.themes.Soft(),
219-
css="""
220-
.gradio-container {
221-
max-width: 700px;
222-
margin: auto;
223-
font-family: Arial, sans-serif;
224-
}
225-
#header {
226-
text-align: center;
227-
font-size: 1.5rem;
228-
font-weight: bold;
229-
color: #008080;
230-
padding: 0.125rem;
231-
}
232-
#input-container {
233-
display: flex;
234-
align-items: center;
235-
background-color: #f7f7f8;
236-
padding: 0.25rem;
237-
border-radius: 8px;
238-
margin-top: 0.25rem;
239-
}
240-
#chatbot {
241-
border: 1px solid #E5E7EB;
242-
border-radius: 8px;
243-
background-color: #FFFFFF;
244-
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
245-
}
246-
.gr-button-primary {
247-
background-color: #008080;
248-
border-color: #008080;
249-
}
250-
.gr-button-primary:hover {
251-
background-color: #006666;
252-
}
253-
"""
254-
) as demo:
48+
title="CAPEC RAG Chatbot",
49+
theme=gr.themes.Soft(),
50+
css="""
51+
.gradio-container {
52+
max-width: 700px;
53+
margin: auto;
54+
font-family: Arial, sans-serif;
55+
display: flex;
56+
flex-direction: column;
57+
height: 100vh;
58+
}
59+
#header {
60+
text-align: center;
61+
font-size: 1.5rem;
62+
font-weight: bold;
63+
color: #008080;
64+
padding: 0.125rem;
65+
flex: 0 0 auto;
66+
}
67+
#input-container {
68+
display: flex;
69+
align-items: center;
70+
background-color: #f7f7f8;
71+
padding: 0.25rem;
72+
border-radius: 8px;
73+
margin-top: 0.25rem;
74+
flex: 0 0 auto;
75+
}
76+
#chatbot {
77+
border: 1px solid #E5E7EB;
78+
border-radius: 8px;
79+
background-color: #FFFFFF;
80+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
81+
flex: 1 1 auto;
82+
min-height: 0;
83+
display: flex;
84+
flex-direction: column;
85+
overflow-y: auto; /* To allow scrolling if content overflows */
86+
min-height: 72vh;
87+
}
88+
.gr-button-primary {
89+
background-color: #008080;
90+
border-color: #008080;
91+
}
92+
.gr-button-primary:hover {
93+
background-color: #006666;
94+
}
95+
"""
96+
) as demo:
25597

25698
# Header
25799
gr.Markdown(
@@ -260,7 +102,6 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
260102

261103
# Chatbot Component
262104
chatbot = gr.Chatbot(
263-
height=450,
264105
show_label=False,
265106
container=True,
266107
elem_id="chatbot"
@@ -277,6 +118,7 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
277118
)
278119
send_button = gr.Button("Send", variant="primary", scale=1)
279120
clear_button = gr.Button("Clear Chat", variant="secondary")
121+
280122

281123
# Button Functionality
282124
send_button.click(
@@ -302,3 +144,4 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
302144
share=False,
303145
debug=True,
304146
show_error=True,)
147+

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ services:
66
server:
77
build:
88
context: .
9-
dockerfile: Dockerfile.server
9+
dockerfile: src/docker-files/Dockerfile.server
1010
ports:
1111
- "8000:8000"
1212
networks:
@@ -20,7 +20,7 @@ services:
2020
client:
2121
build:
2222
context: .
23-
dockerfile: Dockerfile.client
23+
dockerfile: src/docker-files/Dockerfile.client
2424
networks:
2525
- capec-network
2626
ports:

requirements.txt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
# sentence-transformers==3.2.1
1+
fastapi==0.115.5
2+
websockets==12.0
23
qdrant-client==1.12.0
34
python-multipart==0.0.16
45
scikit-learn==1.5.2
5-
langchain-core==0.3.13
6+
langchain-core==0.3.18
67
gradio==4.44.1
78
loguru==0.7.2
8-
groq==0.11.0
99
python-dotenv==1.0.1
10-
fastapi
11-
websockets
12-
# google-auth-oauthlib==1.2.1
13-
# google-auth==2.35.0
14-
# google-api-python-client==2.151.0
1510
llama-index==0.11.21
1611
llama-index-vector-stores-qdrant==0.3.3
17-
llama-index-llms-groq==0.2.0
18-
llama-index-embeddings-huggingface==0.3.1
12+
langchain-groq==0.2.1
13+
llama-index-embeddings-huggingface==0.3.1
14+
langchain==0.3.7

0 commit comments

Comments
 (0)