Skip to content

Commit 0abe0ee

Browse files
authored
Merge commit from fork
Add input validation for table_name and sample_size parameters to prevent KQL injection attacks (GHSA-vphc-468g-8rfp).

The table_name parameter was interpolated directly into KQL queries via f-strings, allowing arbitrary query execution through pipe, newline, and semicolon injection.

- Add validate_table_name() with strict regex allowlist
- Add validate_sample_size() for positive integer enforcement
- Add 24 tests covering injection vectors and valid inputs
1 parent 48b2933 commit 0abe0ee

2 files changed

Lines changed: 200 additions & 1 deletion

File tree

src/adx_mcp_server/server.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
import os
8+
import re
89
import sys
910
from typing import Any, Dict, List, Optional
1011
from dataclasses import dataclass
@@ -170,6 +171,31 @@ def format_query_results(result_set) -> List[Dict[str, Any]]:
170171
)
171172
raise
172173

174+
# Allowlist for table names: one or more identifiers (letter or underscore
# first, then letters, digits or underscores) joined by single dots.
# The end anchor is \Z rather than $ because $ also matches just before a
# trailing newline, which would let "name\n" through if the pattern were ever
# used on an unstripped string.
_TABLE_NAME_PATTERN = re.compile(
    r'^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*\Z'
)


def validate_table_name(table_name: str) -> str:
    """Validate a KQL table name to prevent injection attacks.

    Allows simple identifiers (``my_table``) and dot-qualified names
    (``database.table``). Surrounding whitespace is stripped before the
    check; anything else outside the allowlist is rejected.

    Args:
        table_name: The caller-supplied table name.

    Returns:
        The table name with leading and trailing whitespace removed.

    Raises:
        ValueError: If the value is empty or whitespace-only, is not a
            string, or contains anything other than dot-separated
            identifiers.
    """
    if not table_name:
        raise ValueError("Table name cannot be empty")
    # Reject non-strings with the same exception type as every other invalid
    # input, instead of letting .strip() fail with AttributeError.
    if not isinstance(table_name, str):
        raise ValueError(
            f"Table name must be a string, got: {type(table_name).__name__}"
        )

    candidate = table_name.strip()
    if not candidate:
        raise ValueError("Table name cannot be empty")

    # fullmatch requires the whole string to match the allowlist.
    if not _TABLE_NAME_PATTERN.fullmatch(candidate):
        raise ValueError(
            f"Invalid table name: '{candidate}'. "
            "Table names must contain only letters, digits, underscores, "
            "and dots (for qualified names like 'database.table')."
        )
    return candidate
192+
193+
def validate_sample_size(sample_size: int) -> int:
    """Validate that ``sample_size`` is a positive integer.

    ``bool`` is rejected explicitly: it is a subclass of ``int``, so ``True``
    would otherwise pass the integer check and be returned as a valid size.

    Args:
        sample_size: The caller-supplied number of rows to sample.

    Returns:
        ``sample_size`` unchanged when it is an ``int`` greater than zero.

    Raises:
        ValueError: If ``sample_size`` is not an ``int`` (including when it
            is a ``bool``) or is zero or negative.
    """
    is_plain_int = isinstance(sample_size, int) and not isinstance(sample_size, bool)
    if not is_plain_int or sample_size <= 0:
        raise ValueError(f"sample_size must be a positive integer, got: {sample_size}")
    return sample_size
198+
173199
@mcp.tool(description="Executes a Kusto Query Language (KQL) query against the configured Azure Data Explorer database and returns the results as a list of dictionaries.")
174200
async def execute_query(query: str) -> List[Dict[str, Any]]:
175201
"""Execute a KQL query against the configured ADX database."""
@@ -217,6 +243,7 @@ async def list_tables() -> List[Dict[str, Any]]:
217243
@mcp.tool(description="Retrieves the schema information for a specified table in the Azure Data Explorer database, including column names, data types, and other schema-related metadata.")
218244
async def get_table_schema(table_name: str) -> List[Dict[str, Any]]:
219245
"""Get schema information for a specific table."""
246+
table_name = validate_table_name(table_name)
220247
logger.info("Getting table schema", table_name=table_name, database=config.database)
221248

222249
if not config.cluster_url or not config.database:
@@ -237,6 +264,8 @@ async def get_table_schema(table_name: str) -> List[Dict[str, Any]]:
237264
@mcp.tool(description="Retrieves a random sample of rows from the specified table in the Azure Data Explorer database. The sample_size parameter controls how many rows to return (default: 10).")
238265
async def sample_table_data(table_name: str, sample_size: int = 10) -> List[Dict[str, Any]]:
239266
"""Get sample data from a table."""
267+
table_name = validate_table_name(table_name)
268+
sample_size = validate_sample_size(sample_size)
240269
logger.info("Sampling table data", table_name=table_name, sample_size=sample_size, database=config.database)
241270

242271
if not config.cluster_url or not config.database:
@@ -257,6 +286,7 @@ async def sample_table_data(table_name: str, sample_size: int = 10) -> List[Dict
257286
@mcp.tool(description="Retrieves table details including TotalRowCount, HotExtentSize")
258287
async def get_table_details(table_name: str) -> List[Dict[str, Any]]:
259288
"""Get detailed statistics and metadata for a table."""
289+
table_name = validate_table_name(table_name)
260290
logger.info("Getting table details", table_name=table_name, database=config.database)
261291

262292
if not config.cluster_url or not config.database:

tests/test_all_tools.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from unittest.mock import patch, MagicMock
88

9-
from adx_mcp_server.server import config
9+
from adx_mcp_server.server import config, validate_table_name, validate_sample_size
1010

1111

1212
class TestListTablesTool:
@@ -352,3 +352,172 @@ async def test_get_table_details_error(self):
352352
finally:
353353
config.cluster_url = original_url
354354
config.database = original_db
355+
356+
357+
class TestValidateTableName:
    """Unit tests for validate_table_name: accepted names and rejected payloads."""

    @staticmethod
    def _expect_rejection(candidate, message="Invalid table name"):
        """Assert that ``candidate`` is refused with a ValueError matching ``message``."""
        with pytest.raises(ValueError, match=message):
            validate_table_name(candidate)

    # --- names that must be accepted ---

    def test_simple_table_name(self):
        validated = validate_table_name("my_table")
        assert validated == "my_table"

    def test_qualified_table_name(self):
        validated = validate_table_name("database.table")
        assert validated == "database.table"

    def test_multi_qualified_name(self):
        validated = validate_table_name("db.schema.table")
        assert validated == "db.schema.table"

    def test_underscore_prefix(self):
        validated = validate_table_name("_private_table")
        assert validated == "_private_table"

    def test_alphanumeric(self):
        validated = validate_table_name("table123")
        assert validated == "table123"

    def test_strips_whitespace(self):
        validated = validate_table_name(" my_table ")
        assert validated == "my_table"

    # --- injection payloads and malformed names that must be rejected ---

    def test_pipe_injection(self):
        self._expect_rejection("sensitive_data | project Secret | take 100 //")

    def test_newline_injection(self):
        self._expect_rejection("users\n.drop table critical_data")

    def test_semicolon_injection(self):
        self._expect_rejection("table; .drop table other")

    def test_bracket_notation(self):
        self._expect_rejection("['injected query']")

    def test_space_in_name(self):
        self._expect_rejection("table name with spaces")

    def test_empty_string(self):
        self._expect_rejection("", message="cannot be empty")

    def test_whitespace_only(self):
        self._expect_rejection(" ", message="cannot be empty")

    def test_starts_with_digit(self):
        self._expect_rejection("123table")

    def test_hyphen_in_name(self):
        self._expect_rejection("my-table")

    def test_slash_comment_injection(self):
        self._expect_rejection("table // comment")

    def test_trailing_dot(self):
        self._expect_rejection("database.")

    def test_leading_dot(self):
        self._expect_rejection(".table")
425+
426+
427+
class TestValidateSampleSize:
    """Unit tests for validate_sample_size."""

    # Message fragment every rejection is expected to carry.
    _ERROR = "sample_size must be a positive integer"

    def test_valid_sample_size(self):
        accepted = validate_sample_size(10)
        assert accepted == 10

    def test_zero(self):
        # Zero is not a positive size.
        with pytest.raises(ValueError, match=self._ERROR):
            validate_sample_size(0)

    def test_negative(self):
        with pytest.raises(ValueError, match=self._ERROR):
            validate_sample_size(-5)
440+
441+
442+
class TestTableNameInjectionPrevention:
    """Integration tests: the tool handlers themselves must refuse bad input."""

    @staticmethod
    async def _call_tool(tool_name, *args):
        """Invoke a server tool handler under a temporary test configuration.

        The cluster URL and database on the shared ``config`` object are
        overridden for the duration of the call and restored afterwards,
        whether or not the handler raises.
        """
        from adx_mcp_server import server

        saved_url, saved_db = config.cluster_url, config.database
        config.cluster_url = "https://test.kusto.windows.net"
        config.database = "testdb"
        try:
            tool = getattr(server, tool_name)
            # The tool decorator may wrap the coroutine in an object that
            # exposes it as ``.fn``; otherwise the attribute is the coroutine.
            handler = tool.fn if hasattr(tool, 'fn') else tool
            return await handler(*args)
        finally:
            config.cluster_url = saved_url
            config.database = saved_db

    @pytest.mark.asyncio
    async def test_get_table_schema_rejects_injection(self):
        payload = "sensitive_data | project Secret | take 100 //"
        with pytest.raises(ValueError, match="Invalid table name"):
            await self._call_tool("get_table_schema", payload)

    @pytest.mark.asyncio
    async def test_sample_table_data_rejects_injection(self):
        payload = "data | take 100 //"
        with pytest.raises(ValueError, match="Invalid table name"):
            await self._call_tool("sample_table_data", payload, 10)

    @pytest.mark.asyncio
    async def test_sample_table_data_rejects_invalid_sample_size(self):
        # The table name is valid, so only the sample size can trigger the error.
        with pytest.raises(ValueError, match="sample_size must be a positive integer"):
            await self._call_tool("sample_table_data", "valid_table", -1)

    @pytest.mark.asyncio
    async def test_get_table_details_rejects_injection(self):
        payload = "users details\n.drop table critical_data"
        with pytest.raises(ValueError, match="Invalid table name"):
            await self._call_tool("get_table_details", payload)

0 commit comments

Comments
 (0)