-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot.py
More file actions
294 lines (237 loc) · 11.1 KB
/
Copy pathchatbot.py
File metadata and controls
294 lines (237 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""
Core chatbot logic for the Financial Data Chatbot.
Handles data loading, LLM-based SQL generation, query execution, and result explanation.
"""
import json
import os
import logging
from typing import Any, Dict, List
import duckdb
import pandas as pd
from openai import OpenAI
# Module-level logger used by the pipeline functions for info/warning messages.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Database setup (module-level singleton)
# ---------------------------------------------------------------------------
# Runs at import time: a single in-memory DuckDB connection shared by every
# query in this module. The two CSV files located next to this source file are
# read into pandas and registered on the connection under the table names
# "trades" and "holdings", which is how the generated SQL refers to them.
conn = duckdb.connect(":memory:")
trades_df = pd.read_csv(os.path.join(os.path.dirname(__file__), "trades.csv"))
conn.register("trades", trades_df)
holdings_df = pd.read_csv(os.path.join(os.path.dirname(__file__), "holdings.csv"))
conn.register("holdings", holdings_df)
# Distinct fund names per table, sorted by name. Because a single column is
# selected, fetchall() yields a list of 1-tuples; decide_data_source and
# generate_sql interpolate these lists as-is into their prompts so the model
# can match user-typed fund names against the real ones.
trades_funds: List[tuple] = conn.execute(
    "SELECT DISTINCT PortfolioName FROM trades ORDER BY PortfolioName"
).fetchall()
holdings_funds: List[tuple] = conn.execute(
    "SELECT DISTINCT PortfolioName FROM holdings ORDER BY PortfolioName"
).fetchall()
logger.info(
    "Data loaded — %d trade records, %d holdings records",
    len(trades_df),
    len(holdings_df),
)
# ---------------------------------------------------------------------------
# Schema metadata
# ---------------------------------------------------------------------------
# Human-readable description of each registered table. generate_sql embeds the
# entry for every selected table (via json.dumps) into its prompt, so this text
# is the only schema the LLM sees. Top-level keys are the table names the
# DataFrames are registered under on ``conn``.
# NOTE(review): column keys presumably mirror the CSV headers exactly; verify
# against trades.csv / holdings.csv whenever either file changes.
SCHEMA_METADATA: Dict[str, Any] = {
    "trades": {
        "description": "Transaction records of buy/sell/cover trades executed in various funds",
        # Column name -> description, passed verbatim to the model.
        "columns": {
            "id": "Unique trade identifier",
            "RevisionId": "Trade revision number",
            "TradeTypeName": "Type of trade (Buy, Sell, Sell Short, Buy Fixed/Floating Rate, Cover Short, Buy Protection)",
            "SecurityId": "Unique identifier for security/asset",
            "SecurityType": "Type of security (Equity, Bond, etc.)",
            "Name": "Security name/description",
            "Ticker": "Stock/Security ticker symbol",
            "TradeDate": "Date trade was executed",
            "SettleDate": "Date trade settled",
            "Quantity": "Number of shares/units traded",
            "Price": "Price per unit at trade",
            "Principal": "Total transaction value (Quantity × Price)",
            "Interest": "Interest accrued/paid on trade",
            "TotalCash": "Total cash flow for trade",
            "PortfolioName": "Fund/Portfolio name",
            "CustodianName": "Custodian/broker handling the trade",
            "StrategyName": "Investment strategy",
            "Counterparty": "Entity on other side of trade",
            "AllocationRule": "How trade is allocated across sub-portfolios",
        },
        # Hint to the model about commonly filtered columns; this file only
        # uses it as prompt text.
        "key_filters": ["PortfolioName", "TradeTypeName", "SecurityType", "Counterparty"],
    },
    "holdings": {
        "description": "Current positions held in each fund as of a specific date",
        # Column name -> description, passed verbatim to the model.
        "columns": {
            "AsOfDate": "Date for which holdings are reported",
            "PortfolioName": "Fund/Portfolio name",
            "SecurityId": "Unique identifier for security/asset",
            "SecurityTypeName": "Type of security (Bond, Equity, etc.)",
            "SecName": "Security identifier/ISIN",
            "Qty": "Current quantity held",
            "Price": "Current price per unit",
            "FXRate": "Foreign exchange rate for currency conversion",
            "MV_Local": "Market value in local currency (Qty × Price)",
            "MV_Base": "Market value in base currency (MV_Local × FXRate)",
            "PL_DTD": "Profit/Loss Day-To-Date",
            "PL_QTD": "Profit/Loss Quarter-To-Date",
            "PL_MTD": "Profit/Loss Month-To-Date",
            "PL_YTD": "Profit/Loss Year-To-Date (cumulative gain/loss from start of year)",
            "StartQty": "Quantity at start of year",
            "StartPrice": "Price at start of year",
            "StartFXRate": "FX rate at start of year",
        },
        # Hint to the model about commonly filtered columns; this file only
        # uses it as prompt text.
        "key_filters": ["PortfolioName", "SecurityTypeName", "AsOfDate"],
    },
}
# ---------------------------------------------------------------------------
# OpenAI client (lazy initialization so the module can be imported without a key)
# ---------------------------------------------------------------------------
def _get_openai_client() -> OpenAI:
    """Build an OpenAI client authenticated from ``OPENAI_API_KEY``.

    The key is read on each call, not at import time, so the module (and its
    data loading) works even when no key is configured yet.

    Returns:
        A freshly constructed ``OpenAI`` client.

    Raises:
        EnvironmentError: if ``OPENAI_API_KEY`` is unset or empty.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise EnvironmentError("OPENAI_API_KEY environment variable is not set.")
# ---------------------------------------------------------------------------
# Pipeline steps
# ---------------------------------------------------------------------------
def _normalize_source_decision(decision: Any) -> Dict[str, Any]:
    """Coerce a parsed routing reply into a dict with a usable ``sources`` list.

    The prompt labels the options "TRADES"/"HOLDINGS"/"BOTH", so replies such
    as ``["TRADES"]`` or ``"both"`` are plausible even though lowercase table
    names are requested. Downstream membership checks are case-sensitive, so
    names are lowercased, "both" is expanded, unknown entries and duplicates
    are dropped, and an empty result falls back to both tables.

    Args:
        decision: whatever ``json.loads`` produced for the model reply.

    Returns:
        A shallow copy of ``decision`` (or a default dict when it is not a
        dict) whose ``"sources"`` is a non-empty list drawn from
        ``"trades"`` and ``"holdings"``.
    """
    known = ("trades", "holdings")
    if not isinstance(decision, dict):
        return {"sources": list(known), "reasoning": "Default to both sources"}

    raw_sources = decision.get("sources")
    if isinstance(raw_sources, str):
        raw_sources = [raw_sources]
    if not isinstance(raw_sources, list):
        raw_sources = []

    sources: List[str] = []
    for item in raw_sources:
        name = str(item).strip().lower()
        expanded = known if name == "both" else (name,)
        for candidate in expanded:
            if candidate in known and candidate not in sources:
                sources.append(candidate)

    normalized = dict(decision)
    normalized["sources"] = sources or list(known)
    return normalized


def decide_data_source(user_question: str) -> Dict[str, Any]:
    """Ask the LLM which data source(s) are needed to answer the question.

    Args:
        user_question: free-text question from the user.

    Returns:
        A dict with ``"sources"`` (a non-empty subset of
        ``["trades", "holdings"]``) and, normally, ``"reasoning"``. Replies
        that cannot be parsed as JSON fall back to both tables.

    Raises:
        EnvironmentError: if OPENAI_API_KEY is not set.
        Errors raised by the OpenAI client call propagate unchanged.
    """
    client = _get_openai_client()
    prompt = f"""You are a financial data expert. Analyze this question and decide which data source(s) to use.
Question: {user_question}
Available sources:
1. TRADES - Contains buy/sell transaction history
2. HOLDINGS - Contains current positions and P&L data
3. BOTH - If you need information from both tables
You have the following funds in TRADES: {trades_funds}
You have the following funds in HOLDINGS: {holdings_funds}
Always try to match user input with these fund names, handling case sensitivity and typos.
If the question is about trades, trade types, trade dates, quantities, prices, or counterparties → use "trades".
If the question is about holdings, quantities held, market values, or profit/loss → use "holdings".
Respond in JSON format ONLY:
{{
"sources": ["trades"] or ["holdings"] or ["trades", "holdings"],
"reasoning": "Brief explanation of why these sources"
}}"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        max_tokens=500,
        # Deterministic routing, consistent with explain_results.
        temperature=0,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
    )
    raw = response.choices[0].message.content
    try:
        # The SDK types ``content`` as optional; treat a missing body as
        # unparseable instead of letting json.loads(None) raise TypeError.
        parsed = json.loads(raw or "")
    except json.JSONDecodeError:
        logger.warning("Failed to parse decide_data_source response: %s", raw)
        return {"sources": ["trades", "holdings"], "reasoning": "Default to both sources"}
    return _normalize_source_decision(parsed)
def generate_sql(user_question: str, sources: List[str]) -> Dict[str, Any]:
    """Generate a DuckDB SQL query for the given question using the chosen sources.

    Args:
        user_question: free-text question from the user.
        sources: table names chosen by ``decide_data_source``. Matching is
            case-insensitive, and a bare string is accepted as a single name.
            If neither "trades" nor "holdings" is recognised, both schemas are
            sent, so the model is never prompted without a schema.

    Returns:
        The model's JSON object (expected keys ``"sql"`` and
        ``"explanation"``), or ``{"sql": "", "explanation": "Failed to generate
        SQL"}`` when the reply is not a JSON object with a string ``"sql"``.

    Raises:
        EnvironmentError: if OPENAI_API_KEY is not set.
        Errors raised by the OpenAI client call propagate unchanged.
    """
    client = _get_openai_client()

    known_tables = ("trades", "holdings")
    requested = [sources] if isinstance(sources, str) else list(sources or [])
    wanted = {str(name).strip().lower() for name in requested}
    tables = [table for table in known_tables if table in wanted] or list(known_tables)

    schema_context = ""
    for table in tables:
        schema_context += (
            f"\n{table.upper()} TABLE:\n{json.dumps(SCHEMA_METADATA[table], indent=2)}\n"
        )

    # NOTE(review): rule 5 asks the model to report profit as positive values
    # ("negate if needed"), which could present losses as gains. Confirm this
    # matches the intended sign convention of the PL_* columns.
    prompt = f"""You are a SQL expert for financial data. Generate a DuckDB SQL query to answer this question.
{schema_context}
Question: {user_question}
HARD RULES:
1. Use DuckDB SQL syntax only — NO PostgreSQL-specific syntax.
2. NEVER use ILIKE; use LOWER() comparisons instead.
3. NEVER use ARRAY or ANY; use WHERE LOWER(col) = LOWER('value').
4. For top/best funds by profit: SELECT PortfolioName, SUM(PL_YTD) FROM holdings GROUP BY PortfolioName ORDER BY 2 DESC LIMIT 3
5. Profit should always be expressed as positive values; negate if needed.
6. Handle case sensitivity and typos using LOWER().
7. Available funds in trades: {trades_funds}
8. Available funds in holdings: {holdings_funds}
9. Always answer from data; do not fabricate results.
10. Return valid DuckDB SQL only.
Respond in JSON format ONLY:
{{
"sql": "SELECT ... FROM ...",
"explanation": "What this query does"
}}"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        max_tokens=600,
        # Deterministic SQL generation, consistent with explain_results.
        temperature=0,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
    )
    raw = response.choices[0].message.content
    try:
        # The SDK types ``content`` as optional; treat a missing body as unparseable.
        parsed = json.loads(raw or "")
    except json.JSONDecodeError:
        parsed = None
    # A JSON array/scalar, or a non-string "sql", would crash the caller
    # (.get on a list, .strip() on a non-str), so reject it here.
    if not isinstance(parsed, dict) or not isinstance(parsed.get("sql"), str):
        logger.warning("Failed to parse generate_sql response: %s", raw)
        return {"sql": "", "explanation": "Failed to generate SQL"}
    return parsed
def execute_query(sql: str) -> Dict[str, Any]:
    """Execute a SQL query against the in-memory DuckDB database.

    The statement is executed exactly once, and the column names are taken
    from the same result that supplies the rows. Previously it ran twice,
    doubling the cost and repeating any side effects.

    Args:
        sql: SQL text to run on the module-level ``conn``.

    Returns:
        ``{"success": bool, "data": list[dict], "row_count": int,
        "error": str | None}``, where each ``data`` item maps column name to
        value. Empty SQL and DuckDB errors are reported with
        ``success=False`` rather than raised.
    """
    if not sql or not sql.strip():
        return {"success": False, "data": [], "row_count": 0, "error": "Empty SQL query"}
    # NOTE(review): this SQL is produced by an LLM from user text and runs
    # unrestricted on ``conn``; nothing here limits it to read-only SELECTs.
    # DuckDB SQL can, for example, read or write local files unless external
    # access is disabled. Confirm that this exposure is acceptable.
    try:
        result = conn.execute(sql)
        description = result.description
        rows = result.fetchall()
    except Exception as exc:  # boundary: any DuckDB error is reported to the caller
        logger.warning("Query execution error: %s | SQL: %s", exc, sql)
        return {"success": False, "data": [], "row_count": 0, "error": str(exc)}
    columns = [desc[0] for desc in description] if description else []
    data = [dict(zip(columns, row)) for row in rows]
    return {"success": True, "data": data, "row_count": len(data), "error": None}
def explain_results(
    user_question: str,
    query_data: Dict[str, Any],
    max_rows: int = 100,
) -> str:
    """Translate raw query results into a plain-English business answer.

    Args:
        user_question: the user's original question, echoed into the prompt.
        query_data: a result dict in the shape returned by ``execute_query``,
            with keys ``success``, ``data`` (list of row dicts), ``row_count``
            and ``error``.
        max_rows: maximum number of rows embedded in the LLM prompt, so a
            large result set cannot overflow the model's context window. The
            prompt still states the full row count and how many rows were
            omitted. Values <= 0 disable the cap.

    Returns:
        The model's answer. A fixed apology string is returned instead when
        the query failed, returned no rows, or the model returned no text.

    Raises:
        EnvironmentError: if OPENAI_API_KEY is not set (only reached when
            there are rows to explain).
    """
    if not query_data["success"]:
        return f"Sorry, I couldn't find the answer. Error: {query_data['error']}"
    if query_data["row_count"] == 0:
        return "Sorry, I could not find any matching data for your question."

    client = _get_openai_client()

    rows = query_data["data"]
    shown = rows if max_rows <= 0 else rows[:max_rows]
    # default=str: DuckDB can return values such as datetime.date or Decimal
    # (e.g. from CAST(... AS DATE) or DECIMAL arithmetic). json.dumps cannot
    # encode these and would raise TypeError; their str() form is what the
    # model should read anyway.
    data_str = json.dumps(shown, indent=2, default=str)
    omitted = len(rows) - len(shown)
    if omitted:
        data_str += f"\n... ({omitted} more rows not shown)"

    prompt = f"""You are a financial analyst explaining query results to business users.
Original Question: {user_question}
Query Results ({query_data['row_count']} rows):
{data_str}
Instructions:
1. Answer the question directly and concisely.
2. Include specific numbers and fund names from the data.
3. Highlight significant findings.
4. Keep the answer under 150 words.
5. Use business-friendly plain text — no asterisks, bold, italics, or markdown.
6. Do NOT mention SQL queries or technical data structures.
"""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        temperature=0,
    )
    answer = response.choices[0].message.content
    # The SDK types ``content`` as optional; keep the declared ``-> str`` contract.
    if answer is None:
        return "Sorry, I could not summarise the results for your question."
    return answer
# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------
def financial_chatbot(user_question: str) -> Dict[str, Any]:
    """Answer a natural-language question about the fund data end to end.

    The steps are: route to table(s), generate DuckDB SQL, execute it, and
    have the LLM explain the resulting rows.

    Args:
        user_question: free-text question from the user.

    Returns:
        A dict with keys:
            answer (str): natural-language answer.
            sql (str): generated SQL, kept for transparency ("" if generation failed).
            sources (list): data sources chosen by the router.
            error (str | None): generation or execution error, if any.

    Raises:
        Exceptions from the LLM-calling steps (e.g. a missing OPENAI_API_KEY)
        are not caught here.
    """
    logger.info("Processing question: %s", user_question)

    routing = decide_data_source(user_question)
    chosen = routing.get("sources", ["trades", "holdings"])

    query = generate_sql(user_question, chosen).get("sql", "")
    if query:
        outcome = execute_query(query)
        return {
            "answer": explain_results(user_question, outcome),
            "sql": query,
            "sources": chosen,
            "error": outcome.get("error"),
        }

    # Nothing usable came back from SQL generation; skip execution entirely.
    return {
        "answer": "Sorry, I could not generate a query for your question. Please try rephrasing.",
        "sql": "",
        "sources": chosen,
        "error": "SQL generation failed",
    }