Skip to content

Commit 463fc5b

Browse files
committed
refactor: standardize JSON output and add error handling across tools
- Convert DataFrame results to JSON format using to_json() for consistent output - Add try/except blocks with descriptive error messages to all query functions - Update docstrings to reflect JSON-formatted return values - Add new schema tools: list_sinks() and show_create_sink() - Import json module in query_tools for enhanced JSON serialization
1 parent 8d03826 commit 463fc5b

File tree

4 files changed

+116
-37
lines changed

4 files changed

+116
-37
lines changed

src/explain.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def explain_analyze(query: str) -> str:
1313
query: The SQL query to analyze streaming job (TABLE, MATERIALIZED VIEW, SINK, INDEX, or ID)
1414
1515
Returns:
16-
Query execution plan with actual statistics
16+
Query execution plan with detailed runtime statistics
1717
"""
1818
query_upper = query.strip().upper()
1919
allowed_keywords = ["TABLE", "MATERIALIZED VIEW", "SINK", "INDEX"]
@@ -26,7 +26,7 @@ def explain_analyze(query: str) -> str:
2626
try:
2727
explain_query = f"EXPLAIN ANALYZE {query}"
2828
result = rw.fetch(explain_query, format=OutputFormat.DATAFRAME)
29-
return result
29+
return result.to_json()
3030
except Exception as e:
3131
return f"Error executing EXPLAIN ANALYZE: {str(e)}"
3232

@@ -42,7 +42,7 @@ def explain_query(query: str) -> str:
4242
query: The SQL query to explain (SELECT, INSERT, UPDATE, DELETE)
4343
4444
Returns:
45-
Query execution plan with estimated statistics
45+
Query execution plan
4646
"""
4747
query_upper = query.strip().upper()
4848
allowed_queries = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'WITH']
@@ -59,6 +59,6 @@ def explain_query(query: str) -> str:
5959
try:
6060
explain_query = f"EXPLAIN {query}"
6161
result = rw.fetch(explain_query, format=OutputFormat.DATAFRAME)
62-
return result
62+
return result.to_json()
6363
except Exception as e:
6464
return f"Error executing EXPLAIN: {str(e)}"

src/management_tools.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ def get_database_version() -> str:
1515
Database version information
1616
"""
1717
rw = setup_risingwave_connection()
18-
result = rw.fetchone("SELECT version()", format=OutputFormat.DATAFRAME)
19-
return result
18+
try:
19+
result = rw.fetchone("SELECT version()", format=OutputFormat.DATAFRAME)
20+
return result.to_json()
21+
except Exception as e:
22+
return f"Error getting database version: {str(e)}"
2023

2124
@mcp.tool
2225
def show_running_queries() -> str:
@@ -31,7 +34,7 @@ def show_running_queries() -> str:
3134
# This may not be available in all RisingWave versions
3235
result = rw.fetch("SHOW PROCESSLIST",
3336
format=OutputFormat.DATAFRAME)
34-
return result
37+
return result.to_json()
3538
except Exception as e:
3639
return f"Show running queries not supported or error occurred: {str(e)}"
3740

src/query_tools.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from risingwave import OutputFormat
22
from connection import setup_risingwave_connection
3+
import json
34

45

56
def run_select_query(query: str) -> str:
@@ -10,16 +11,20 @@ def run_select_query(query: str) -> str:
1011
query: The SELECT SQL query to execute (must start with SELECT)
1112
1213
Returns:
13-
Query results as a formatted string
14+
Query results as a JSON-formatted string
1415
"""
1516
query_upper = query.strip().upper()
1617
if not query_upper.startswith('SELECT'):
1718
raise ValueError(
1819
"Only SELECT queries are allowed for security reasons")
1920

2021
rw = setup_risingwave_connection()
21-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
22-
return result
22+
try:
23+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
24+
records = result.to_dict(orient='records')
25+
return json.dumps(records, default=str, ensure_ascii=False, indent=2)
26+
except Exception as e:
27+
return f"Error executing SELECT query: {str(e)}"
2328

2429

2530
def table_row_count(table_name: str) -> str:
@@ -30,12 +35,16 @@ def table_row_count(table_name: str) -> str:
3035
table_name: Name of the table
3136
3237
Returns:
33-
Row count as a string
38+
Row count as a JSON-formatted string
3439
"""
3540
rw = setup_risingwave_connection()
3641
query = f"SELECT COUNT(*) as row_count FROM {table_name}"
37-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
38-
return result
42+
try:
43+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
44+
records = result.to_dict(orient='records')
45+
return json.dumps(records, default=str, ensure_ascii=False)
46+
except Exception as e:
47+
return f"Error getting row count for table {table_name}: {str(e)}"
3948

4049

4150
def get_table_stats(table_name: str, schema_name: str = "public") -> str:
@@ -61,7 +70,10 @@ def get_table_stats(table_name: str, schema_name: str = "public") -> str:
6170
FROM information_schema.columns
6271
WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'
6372
"""
64-
column_info = rw.fetchone(column_query, format=OutputFormat.DATAFRAME)
73+
try:
74+
column_info = rw.fetchone(column_query, format=OutputFormat.DATAFRAME)
75+
except Exception as e:
76+
return f"Error getting column info for table {table_name}: {str(e)}"
6577

6678
stats = {
6779
"table": f"{schema_name}.{table_name}",

src/schema_tools.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@ def register_schema_tools(mcp: FastMCP):
1010
def show_tables() -> str:
1111
"""List all tables in the database."""
1212
rw = setup_risingwave_connection()
13-
result = rw.fetch("SHOW TABLES", format=OutputFormat.DATAFRAME)
14-
return result
13+
try:
14+
result = rw.fetch("SHOW TABLES", format=OutputFormat.DATAFRAME)
15+
return result.to_json()
16+
except Exception as e:
17+
return f"Error listing tables: {str(e)}"
1518

1619
@mcp.tool
1720
def list_databases() -> str:
1821
"""List all databases in the RisingWave cluster."""
1922
rw = setup_risingwave_connection()
20-
result = rw.fetch("SHOW DATABASES", format=OutputFormat.DATAFRAME)
21-
return result
23+
try:
24+
result = rw.fetch("SHOW DATABASES", format=OutputFormat.DATAFRAME)
25+
return result.to_json()
26+
except Exception as e:
27+
return f"Error listing databases: {str(e)}"
2228

2329
@mcp.tool
2430
def describe_table(table_name: str) -> str:
@@ -33,8 +39,11 @@ def describe_table(table_name: str) -> str:
3339
"""
3440
rw = setup_risingwave_connection()
3541
query = f"DESCRIBE {table_name}"
36-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
37-
return result
42+
try:
43+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
44+
return result.to_json()
45+
except Exception as e:
46+
return f"Error describing table {table_name}: {str(e)}"
3847

3948
@mcp.tool
4049
def describe_materialized_view(mv_name: str) -> str:
@@ -49,8 +58,11 @@ def describe_materialized_view(mv_name: str) -> str:
4958
"""
5059
rw = setup_risingwave_connection()
5160
query = f"DESCRIBE {mv_name}"
52-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
53-
return result
61+
try:
62+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
63+
return result.to_json()
64+
except Exception as e:
65+
return f"Error describing materialized view {mv_name}: {str(e)}"
5466

5567
@mcp.tool
5668
def show_create_table(table_name: str) -> str:
@@ -65,8 +77,11 @@ def show_create_table(table_name: str) -> str:
6577
"""
6678
rw = setup_risingwave_connection()
6779
query = f"SHOW CREATE TABLE {table_name}"
68-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
69-
return result
80+
try:
81+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
82+
return result.to_json()
83+
except Exception as e:
84+
return f"Error showing create table for {table_name}: {str(e)}"
7085

7186
@mcp.tool
7287
def show_create_materialized_view(mv_name: str) -> str:
@@ -81,8 +96,11 @@ def show_create_materialized_view(mv_name: str) -> str:
8196
"""
8297
rw = setup_risingwave_connection()
8398
query = f"SHOW CREATE MATERIALIZED VIEW {mv_name}"
84-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
85-
return result
99+
try:
100+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
101+
return result.to_json()
102+
except Exception as e:
103+
return f"Error showing create materialized view for {mv_name}: {str(e)}"
86104

87105
@mcp.tool
88106
def check_table_exists(table_name: str, schema_name: str = "public") -> str:
@@ -97,8 +115,11 @@ def check_table_exists(table_name: str, schema_name: str = "public") -> str:
97115
Boolean result as string indicating if table exists
98116
"""
99117
rw = setup_risingwave_connection()
100-
exists = rw.check_exist(name=table_name, schema_name=schema_name)
101-
return f"Table '{table_name}' in schema '{schema_name}' exists: {exists}"
118+
try:
119+
exists = rw.check_exist(name=table_name, schema_name=schema_name)
120+
return f"Table '{table_name}' in schema '{schema_name}' exists: {exists}"
121+
except Exception as e:
122+
return f"Error checking if table exists: {str(e)}"
102123

103124
@mcp.tool
104125
def list_schemas() -> str:
@@ -109,9 +130,12 @@ def list_schemas() -> str:
109130
List of schemas as a formatted string
110131
"""
111132
rw = setup_risingwave_connection()
112-
result = rw.fetch(
113-
"SELECT schema_name FROM information_schema.schemata", format=OutputFormat.DATAFRAME)
114-
return result
133+
try:
134+
result = rw.fetch(
135+
"SELECT schema_name FROM information_schema.schemata", format=OutputFormat.DATAFRAME)
136+
return result.to_json()
137+
except Exception as e:
138+
return f"Error listing schemas: {str(e)}"
115139

116140
@mcp.tool
117141
def list_materialized_views() -> str:
@@ -126,8 +150,11 @@ def list_materialized_views() -> str:
126150
"""
127151
rw = setup_risingwave_connection()
128152
query = "SHOW MATERIALIZED VIEWS"
129-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
130-
return result
153+
try:
154+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
155+
return result.to_json()
156+
except Exception as e:
157+
return f"Error listing materialized views: {str(e)}"
131158

132159
@mcp.tool
133160
def get_table_columns(table_name: str, schema_name: str = "public") -> str:
@@ -148,8 +175,11 @@ def get_table_columns(table_name: str, schema_name: str = "public") -> str:
148175
WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'
149176
ORDER BY ordinal_position
150177
"""
151-
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
152-
return result
178+
try:
179+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
180+
return result.to_json()
181+
except Exception as e:
182+
return f"Error getting table columns for {table_name}: {str(e)}"
153183

154184
@mcp.tool
155185
def list_subscriptions(schema_name: str = "public") -> str:
@@ -166,7 +196,7 @@ def list_subscriptions(schema_name: str = "public") -> str:
166196
query = f"SHOW SUBSCRIPTIONS FROM {schema_name}"
167197
try:
168198
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
169-
return result
199+
return result.to_json()
170200
except Exception as e:
171201
return f"Error listing subscriptions: {str(e)}"
172202

@@ -190,6 +220,40 @@ def list_table_privileges(table_name: str, schema_name: str = "public") -> str:
190220
"""
191221
try:
192222
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
193-
return result
223+
return result.to_json()
194224
except Exception as e:
195225
return f"Error getting table privileges: {str(e)}"
226+
227+
@mcp.tool
228+
def list_sinks() -> str:
229+
"""
230+
List all sinks in the RisingWave database.
231+
232+
Returns:
233+
List of sinks as a formatted string
234+
"""
235+
rw = setup_risingwave_connection()
236+
query = "SHOW SINKS"
237+
try:
238+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
239+
return result.to_json()
240+
except Exception as e:
241+
return f"Error listing sinks: {str(e)}"
242+
243+
@mcp.tool
244+
def show_create_sink(sink_name: str) -> str:
245+
"""
246+
Show the CREATE SINK statement for a specific sink.
247+
248+
Args:
249+
sink_name: Name of the sink
250+
Returns:
251+
CREATE SINK statement
252+
"""
253+
rw = setup_risingwave_connection()
254+
query = f"SHOW CREATE SINK {sink_name}"
255+
try:
256+
result = rw.fetch(query, format=OutputFormat.DATAFRAME)
257+
return result.to_json()
258+
except Exception as e:
259+
return f"Error showing create sink: {str(e)}"

0 commit comments

Comments
 (0)