Skip to content

Commit 65d490e

Browse files
authored
Merge pull request #3 from cyber-evangelists/dev-branch
Dev branch
2 parents 36308c8 + bb61eb9 commit 65d490e

File tree

12 files changed

+263
-64
lines changed

12 files changed

+263
-64
lines changed

capec-dataset/test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
import pandas as pd
3+
from pathlib import Path
4+
5+
6+
7+
8+
def read_file( file_path: Path) -> pd.DataFrame:
9+
df = pd.read_csv(file_path,
10+
sep=',',
11+
encoding='utf-8',
12+
skipinitialspace=True, index_col=None)
13+
14+
df.columns = df.columns.map(lambda x: x.strip("'\""))
15+
df_reset = df.reset_index(drop=False)
16+
17+
col_names = df.columns
18+
19+
df.columns = col_names
20+
21+
df = df_reset.iloc[:, :-1]
22+
23+
df.columns = col_names
24+
25+
return df
26+
27+
28+
df = read_file("333.csv")
29+
30+
print(df.columns)

client.py

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# client.py
21
import gradio as gr
32
import websockets
43
import json
@@ -117,7 +116,7 @@ async def handle_request(
117116
try:
118117

119118
logger.info("Ensuring Connection....")
120-
await ws_client.ensure_connection()
119+
await self.ensure_connection()
121120

122121
result = await self._handle_websocket_communication(action, payload)
123122
return result
@@ -176,10 +175,9 @@ async def _handle_websocket_communication(
176175
except Exception as e:
177176
logger.error(f"Communication error: {e}")
178177
return "", [(payload.get("query", ""), f"Communication error: {str(e)}")]
179-
180178

181179

182-
# Create WebSocket client instance
180+
183181
ws_client = WebSocketClient(Config.WEBSOCKET_URI)
184182

185183

@@ -216,50 +214,84 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
216214

217215
# Create Gradio interface
218216
with gr.Blocks(
219-
title="CAPEC Chatbot",
220-
theme=gr.themes.Soft(),
221-
css=".gradio-container {max-width: 800px; margin: auto}"
222-
) as demo:
223-
gr.Markdown("""
224-
# ASM Chatbot
225-
Ask questions about CAPEC Dataset and get detailed responses.
226-
""")
227-
228-
with gr.Row():
229-
msg = gr.Textbox(
230-
label="Type your message here...",
231-
placeholder="Enter your query",
232-
show_label=True,
233-
container=True,
234-
scale=8
235-
)
236-
237-
with gr.Row():
238-
search_btn = gr.Button("Search", variant="primary", scale=2)
239-
clear_btn = gr.Button("Clear", variant="secondary", scale=1)
240-
status_box = gr.Textbox(visible=False)
217+
title="Capec Chatbot",
218+
theme=gr.themes.Soft(),
219+
css="""
220+
.gradio-container {
221+
max-width: 700px;
222+
margin: auto;
223+
font-family: Arial, sans-serif;
224+
}
225+
#header {
226+
text-align: center;
227+
font-size: 1.5rem;
228+
font-weight: bold;
229+
color: #008080;
230+
padding: 0.125rem;
231+
}
232+
#input-container {
233+
display: flex;
234+
align-items: center;
235+
background-color: #f7f7f8;
236+
padding: 0.25rem;
237+
border-radius: 8px;
238+
margin-top: 0.25rem;
239+
}
240+
#chatbot {
241+
border: 1px solid #E5E7EB;
242+
border-radius: 8px;
243+
background-color: #FFFFFF;
244+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
245+
}
246+
.gr-button-primary {
247+
background-color: #008080;
248+
border-color: #008080;
249+
}
250+
.gr-button-primary:hover {
251+
background-color: #006666;
252+
}
253+
"""
254+
) as demo:
241255

256+
# Header
257+
gr.Markdown(
258+
"<div id='header'>CAPEC RAG Application</div>"
259+
)
242260

243-
chatbot = gr.Chatbot(
244-
height=400,
245-
show_label=False,
246-
container=True,
247-
elem_id="chatbot"
248-
)
261+
# Chatbot Component
262+
chatbot = gr.Chatbot(
263+
height=450,
264+
show_label=False,
265+
container=True,
266+
elem_id="chatbot"
267+
)
249268

250-
search_btn.click(
251-
fn=search_click,
252-
inputs=[msg, chatbot],
253-
outputs=[msg, chatbot]
254-
)
269+
# Chat Input Row
270+
with gr.Row(elem_id="input-container"):
271+
msg = gr.Textbox(
272+
placeholder="Type a message...",
273+
show_label=False,
274+
container=False,
275+
lines=1,
276+
scale=10,
277+
)
278+
send_button = gr.Button("Send", variant="primary", scale=1)
279+
clear_button = gr.Button("Clear Chat", variant="secondary")
280+
281+
# Button Functionality
282+
send_button.click(
283+
fn=search_click,
284+
inputs=[msg, chatbot],
285+
outputs=[msg, chatbot]
286+
)
287+
clear_button.click(
288+
fn=clear_chat,
289+
inputs=[],
290+
outputs=[chatbot]
291+
)
255292

256-
clear_btn.click(
257-
fn=clear_chat,
258-
inputs=[],
259-
outputs=[chatbot]
260-
)
261293

262-
294+
263295

264296
if __name__ == "__main__":
265297
server_name = Config.GRADIO_SERVER_NAME
@@ -269,4 +301,4 @@ def clear_chat() -> Optional[List[Tuple[str, str]]]:
269301
server_port=server_port,
270302
share=False,
271303
debug=True,
272-
show_error=True,)
304+
show_error=True,)

server.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
22
from loguru import logger
3+
from src.utils.utils import find_file_names
4+
from llama_index.core.vector_stores.types import MetadataFilters, ExactMatchFilter
35

46

57
import asyncio
@@ -22,6 +24,8 @@
2224
from src.utils.connections_manager import ConnectionManager
2325
from src.config.config import Config
2426

27+
from llama_index.core import StorageContext
28+
2529
import os
2630

2731
app = FastAPI()
@@ -35,13 +39,23 @@
3539

3640
embedding_client = EmbeddingWrapper()
3741

42+
# data_dir = Config.CAPEC_DATA_DIR
43+
44+
# storage_context = StorageContext.from_defaults()
45+
46+
# csvParser = CsvParser(data_dir)
47+
# documents = csvParser.process_directory()
48+
# index = qdrantManager.create_and_persist_index(documents, storage_context, embedding_client, Config.PERSIST_DIR)
49+
3850
index = qdrantManager.load_index(persist_dir=Config.PERSIST_DIR, embed_model=embedding_client)
3951

4052
retriever = VectorIndexRetriever(
4153
index=index,
4254
similarity_top_k=5
4355
)
4456

57+
# Manually added file names of the CAPEC daatset. In production, These files will be fetched from database
58+
database_files = ["333.csv", "658.csv", "659.csv", "1000.csv", "3000.csv"]
4559

4660
# Create the connection manager instance
4761
connection_manager = ConnectionManager(max_connections=Config.MAX_CONNECTIONS)
@@ -65,13 +79,23 @@ async def handle_search(websocket: WebSocket, query: str) -> None:
6579
try:
6680
logger.info(f"Processing search query: {query}")
6781

68-
# Generate embeddings
69-
logger.info("Retrieving Relevant nodes")
70-
relevant_nodes = retriever.retrieve(query)
82+
filename = find_file_names(query, database_files)
83+
84+
if filename:
85+
logger.info("Searching for file names...")
86+
87+
filters = MetadataFilters(filters=[ExactMatchFilter(key="source_file", value=filename)])
88+
relevant_nodes = index.as_retriever(filters=filters).retrieve(query)
89+
if not relevant_nodes:
90+
logger.info("Searching without file name filter....")
91+
relevant_nodes = retriever.retrieve(query)
92+
else:
93+
logger.info("Searching without file names....")
94+
relevant_nodes = retriever.retrieve(query)
7195

7296
context = [node.text for node in relevant_nodes]
7397

74-
# Only attaching top 2 results
98+
logger.info(context[:2])
7599
prompt = prepare_prompt(query, context[:2])
76100

77101
# Generate response using Groq
@@ -130,3 +154,5 @@ async def websocket_endpoint(websocket: WebSocket) -> None:
130154
finally:
131155
connection_manager.disconnect(websocket)
132156

157+
158+
-57 MB
Binary file not shown.

src/index/index/default__vector_store.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

src/index/index/docstore.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

src/index/index/index_store.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

src/index/index/metadata.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"saved_at": "2024-11-08T13:47:52.662402",
3+
"index_name": "CAPEC-INDEX",
4+
"num_nodes": 1794
5+
}

src/parser/csv_parser.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
from datetime import datetime
1010
from dataclasses import dataclass
1111

12-
1312
from loguru import logger
1413

1514
from src.config.config import Config
1615

17-
1816
@dataclass
1917
class DocumentMetadata:
2018
"""Class to store document metadata"""
@@ -49,13 +47,34 @@ def create_document_metadata(self, row: pd.Series, file_name: str,) -> DocumentM
4947
)
5048

5149

50+
51+
def read_file(self, file_path: Path) -> pd.DataFrame:
52+
df = pd.read_csv(file_path,
53+
sep=',',
54+
encoding='utf-8',
55+
skipinitialspace=True, index_col=None)
56+
57+
df.columns = df.columns.map(lambda x: x.strip("'\""))
58+
df_reset = df.reset_index(drop=False)
59+
60+
col_names = df.columns
61+
62+
df.columns = col_names
63+
64+
df = df_reset.iloc[:, :-1]
65+
66+
df.columns = col_names
67+
68+
return df
69+
70+
5271
def process_file(self, file_path: Path) -> List[Document]:
5372
"""Process a single CSV file with enhanced metadata and version control"""
5473
try:
5574
logger.info(f"Processing file: {file_path}")
5675

5776
# Read CSV file
58-
df = pd.read_csv(file_path)
77+
df = self.read_file(file_path)
5978

6079
documents = []
6180
for _, row in df.iterrows():
@@ -98,11 +117,11 @@ def get_text(self, row: pd.Series) -> str:
98117
text_parts = []
99118

100119
# Process each column in the row
101-
for col in row.index:
102-
cleaned_text = str(row[col]).strip() if pd.notna(row[col]) else ""
120+
for col, value in row.items(): # Change here to access both col and value
121+
cleaned_text = str(value).strip() if pd.notna(value) else ""
103122
if cleaned_text: # Only include non-empty values
104123
text_parts.append(f"{col}: {cleaned_text}")
105-
124+
106125
# Join all parts with a separator
107126
return " | ".join(text_parts)
108127

0 commit comments

Comments
 (0)