-
Notifications
You must be signed in to change notification settings - Fork 18
feat: add query timeout support to TiDB MCP server #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
|
|
||
| # Constants | ||
| TIDB_SERVERLESS_USERNAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$") | ||
| MCP_QUERY_TIMEOUT: Optional[int] = None | ||
|
|
||
|
|
||
| # TiDB Connector | ||
|
|
@@ -38,22 +39,31 @@ def __init__( | |
| username: Optional[str] = None, | ||
| password: Optional[str] = None, | ||
| database: Optional[str] = None, | ||
| query_timeout: Optional[int] = None, | ||
| ): | ||
| self.query_timeout = query_timeout | ||
| connect_kwargs = {} | ||
| if query_timeout is not None: | ||
| connect_kwargs["connect_args"] = { | ||
| "init_command": f"SET SESSION max_execution_time = {query_timeout * 1000}" | ||
| } | ||
|
|
||
| self.tidb_client = TiDBClient.connect( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [P0] Wrong kwarg to Why: Evidence: |
||
| url=database_url, | ||
| database_url=database_url, | ||
| host=host, | ||
| port=port, | ||
| username=username, | ||
| password=password, | ||
| database=database, | ||
| **connect_kwargs, | ||
| ) | ||
| if database_url: | ||
| uri = MySQLDsn(database_url) | ||
| self.host = uri.host | ||
| self.port = uri.port | ||
| self.username = uri.username | ||
| self.password = uri.password | ||
| self.database = uri.path.lstrip("/") | ||
| self.database = (uri.path or "").lstrip("/") or None | ||
| else: | ||
| self.host = host | ||
| self.port = port | ||
|
|
@@ -70,12 +80,21 @@ def switch_database( | |
| username: Optional[str] = None, | ||
| password: Optional[str] = None, | ||
| ) -> None: | ||
| connect_kwargs = {} | ||
| if self.query_timeout is not None: | ||
| connect_kwargs["connect_args"] = { | ||
| "init_command": ( | ||
| f"SET SESSION max_execution_time = {self.query_timeout * 1000}" | ||
| ) | ||
| } | ||
|
|
||
| self.tidb_client = TiDBClient.connect( | ||
| host=self.host, | ||
| port=self.port, | ||
| username=username or self.username, | ||
| password=password or self.password, | ||
| database=db_name or self.database, | ||
| **connect_kwargs, | ||
| ) | ||
|
|
||
| def show_tables(self) -> list[str]: | ||
|
|
@@ -98,7 +117,7 @@ def execute(self, sql_stmts: str | list[str]) -> list[dict]: | |
|
|
||
| @property | ||
| def is_tidb_serverless(self) -> bool: | ||
| return TIDB_SERVERLESS_HOST_PATTERN.match(self.host) | ||
| return bool(self.host) and bool(TIDB_SERVERLESS_HOST_PATTERN.match(self.host)) | ||
|
|
||
| def current_username(self) -> str: | ||
| current_user = self.tidb_client.query("SELECT CURRENT_USER()").scalar() or "" | ||
|
|
@@ -161,6 +180,7 @@ async def app_lifespan(app: FastMCP) -> AsyncIterator[AppContext]: | |
| username=os.getenv("TIDB_USERNAME", "root"), | ||
| password=os.getenv("TIDB_PASSWORD", ""), | ||
| database=os.getenv("TIDB_DATABASE", "test"), | ||
| query_timeout=MCP_QUERY_TIMEOUT, | ||
| ) | ||
| log.info(f"Connected to TiDB: {tidb.host}:{tidb.port}/{tidb.database}") | ||
| yield AppContext(tidb=tidb) | ||
|
|
@@ -250,8 +270,12 @@ def create_mcp_server( | |
| host: str = "127.0.0.1", | ||
| port: int = 8000, | ||
| stateless_http: bool = True, | ||
| query_timeout: Optional[int] = None, | ||
| ) -> FastMCP: | ||
| """Create and configure the TiDB MCP server.""" | ||
| global MCP_QUERY_TIMEOUT | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [P2] Query timeout stored in module-global state (cross-instance leakage) Why: Multiple Evidence: |
||
| MCP_QUERY_TIMEOUT = query_timeout | ||
|
|
||
| mcp = FastMCP( | ||
| "tidb", | ||
| instructions="""You are a tidb database expert, you can help me query, create, and execute sql | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| from unittest.mock import Mock | ||
|
|
||
| import pytest | ||
|
|
||
| pytest.importorskip("mcp.server.fastmcp") | ||
|
|
||
| import pytidb.ext.mcp as mcp_cli | ||
| import pytidb.ext.mcp.server as mcp_server | ||
|
|
||
|
|
||
| def test_mcp_cli_option_forwards_query_timeout(monkeypatch): | ||
| captured = {} | ||
|
|
||
| class DummyServer: | ||
| def run(self, transport): | ||
| captured["transport"] = transport | ||
|
|
||
| def fake_create_mcp_server(**kwargs): | ||
| captured.update(kwargs) | ||
| return DummyServer() | ||
|
|
||
| monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server) | ||
|
|
||
| mcp_cli.main.main( | ||
| args=["--transport", "streamable-http", "--query-timeout", "12"], | ||
| prog_name="tidb-mcp-server", | ||
| standalone_mode=False, | ||
| ) | ||
|
|
||
| assert captured == { | ||
| "host": "127.0.0.1", | ||
| "port": 8000, | ||
| "stateless_http": True, | ||
| "query_timeout": 12, | ||
| "transport": "streamable-http", | ||
| } | ||
|
|
||
|
|
||
| def test_mcp_cli_env_var_forwards_query_timeout(monkeypatch): | ||
| captured = {} | ||
|
|
||
| class DummyServer: | ||
| def run(self, transport): | ||
| captured["transport"] = transport | ||
|
|
||
| def fake_create_mcp_server(**kwargs): | ||
| captured.update(kwargs) | ||
| return DummyServer() | ||
|
|
||
| monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server) | ||
| monkeypatch.setenv("TIDB_MCP_QUERY_TIMEOUT", "15") | ||
|
|
||
| mcp_cli.main.main( | ||
| args=[], | ||
| prog_name="tidb-mcp-server", | ||
| standalone_mode=False, | ||
| ) | ||
|
|
||
| assert captured == { | ||
| "host": "127.0.0.1", | ||
| "port": 8000, | ||
| "stateless_http": False, | ||
| "query_timeout": 15, | ||
| "transport": "stdio", | ||
| } | ||
|
|
||
|
|
||
| def test_tidb_connector_sets_query_timeout_init_command(monkeypatch): | ||
| calls = [] | ||
|
|
||
| def fake_connect(**kwargs): | ||
| calls.append(kwargs) | ||
| return Mock() | ||
|
|
||
| monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [P1] New tests encode the wrong connect API and mask the production failure Why: Tests monkeypatch Evidence: |
||
|
|
||
| mcp_server.TiDBConnector( | ||
| host="127.0.0.1", | ||
| port=4000, | ||
| username="root", | ||
| password="", | ||
| database="test", | ||
| query_timeout=7, | ||
| ) | ||
|
|
||
| assert calls == [ | ||
| { | ||
| "database_url": None, | ||
| "host": "127.0.0.1", | ||
| "port": 4000, | ||
| "username": "root", | ||
| "password": "", | ||
| "database": "test", | ||
| "connect_args": { | ||
| "init_command": "SET SESSION max_execution_time = 7000" | ||
| }, | ||
| } | ||
| ] | ||
|
|
||
|
|
||
| def test_tidb_connector_preserves_query_timeout_when_switching_databases(monkeypatch): | ||
| calls = [] | ||
|
|
||
| def fake_connect(**kwargs): | ||
| client = Mock() | ||
| client.disconnect = Mock() | ||
| calls.append(kwargs) | ||
| return client | ||
|
|
||
| monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect)) | ||
|
|
||
| connector = mcp_server.TiDBConnector( | ||
| host="127.0.0.1", | ||
| port=4000, | ||
| username="root", | ||
| password="", | ||
| database="test", | ||
| query_timeout=9, | ||
| ) | ||
|
|
||
| connector.switch_database("analytics") | ||
|
|
||
| assert calls[1] == { | ||
| "host": "127.0.0.1", | ||
| "port": 4000, | ||
| "username": "root", | ||
| "password": "", | ||
| "database": "analytics", | ||
| "connect_args": { | ||
| "init_command": "SET SESSION max_execution_time = 9000" | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| def test_app_lifespan_passes_query_timeout(monkeypatch): | ||
| captured = {} | ||
|
|
||
| class FakeConnector: | ||
| def __init__(self, **kwargs): | ||
| captured.update(kwargs) | ||
| self.host = kwargs["host"] | ||
| self.port = kwargs["port"] | ||
| self.database = kwargs["database"] | ||
|
|
||
| def disconnect(self): | ||
| captured["disconnected"] = True | ||
|
|
||
| monkeypatch.setattr(mcp_server, "TiDBConnector", FakeConnector) | ||
| monkeypatch.setattr(mcp_server, "MCP_QUERY_TIMEOUT", 15) | ||
|
|
||
| async def run_lifespan(): | ||
| async with mcp_server.app_lifespan(Mock()): | ||
| pass | ||
|
|
||
| import asyncio | ||
|
|
||
| asyncio.run(run_lifespan()) | ||
|
|
||
| assert captured["query_timeout"] == 15 | ||
| assert captured["disconnected"] is True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is recommended that we add a subsection about advanced settings in the MCP Server section of the official documentation.
https://github.com/pingcap/docs/edit/master/ai/integrations/tidb-mcp-server.md