diff --git a/README.md b/README.md index ecac972..5e6777a 100644 --- a/README.md +++ b/README.md @@ -256,3 +256,5 @@ with tidb_client.session() as session: > Click the button below to install **TiDB MCP Server** in Cursor. Then, confirm by clicking **Install** when prompted. > > [![Install TiDB MCP Server](https://cursor.com/deeplink/mcp-install-dark.svg)](https://cursor.com/install-mcp?name=TiDB&config=eyJjb21tYW5kIjoidXZ4IC0tZnJvbSBweXRpZGJbbWNwXSB0aWRiLW1jcC1zZXJ2ZXIiLCJlbnYiOnsiVElEQl9IT1NUIjoibG9jYWxob3N0IiwiVElEQl9QT1JUIjoiNDAwMCIsIlRJREJfVVNFUk5BTUUiOiJyb290IiwiVElEQl9QQVNTV09SRCI6IiIsIlRJREJfREFUQUJBU0UiOiJ0ZXN0In19) +> +> To limit long-running MCP queries, set `TIDB_MCP_QUERY_TIMEOUT` to the maximum execution time in seconds. diff --git a/pytidb/ext/mcp/__init__.py b/pytidb/ext/mcp/__init__.py index a10e185..9d39d23 100644 --- a/pytidb/ext/mcp/__init__.py +++ b/pytidb/ext/mcp/__init__.py @@ -30,10 +30,18 @@ show_default=True, help="Port to bind for network transports", ) +@click.option( + "--query-timeout", + type=click.IntRange(min=1), + envvar="TIDB_MCP_QUERY_TIMEOUT", + default=None, + help="Maximum execution time for TiDB queries in seconds", +) def main( transport: Literal["stdio", "sse", "streamable-http"] = "stdio", host: str = "127.0.0.1", port: int = 8000, + query_timeout: int | None = None, ): logging.basicConfig( level=logging.INFO, @@ -47,5 +55,6 @@ def main( host=host, port=port, stateless_http=stateless, + query_timeout=query_timeout, ) mcp.run(transport=transport) diff --git a/pytidb/ext/mcp/server.py b/pytidb/ext/mcp/server.py index 9ab5ca2..ef2de46 100644 --- a/pytidb/ext/mcp/server.py +++ b/pytidb/ext/mcp/server.py @@ -25,6 +25,7 @@ # Constants TIDB_SERVERLESS_USERNAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$") +MCP_QUERY_TIMEOUT: Optional[int] = None # TiDB Connector @@ -38,14 +39,23 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, database: Optional[str] = None, + query_timeout: Optional[int] = None, ): + self.query_timeout = query_timeout + connect_kwargs = {} + if query_timeout is not None: + connect_kwargs["connect_args"] = { + "init_command": f"SET SESSION max_execution_time = {query_timeout * 1000}" + } + self.tidb_client = TiDBClient.connect( - url=database_url, + database_url=database_url, host=host, port=port, username=username, password=password, database=database, + **connect_kwargs, ) if database_url: uri = MySQLDsn(database_url) @@ -53,7 +63,7 @@ def __init__( self.port = uri.port self.username = uri.username self.password = uri.password - self.database = uri.path.lstrip("/") + self.database = (uri.path or "").lstrip("/") or None else: self.host = host self.port = port @@ -70,12 +80,21 @@ def switch_database( username: Optional[str] = None, password: Optional[str] = None, ) -> None: + connect_kwargs = {} + if self.query_timeout is not None: + connect_kwargs["connect_args"] = { + "init_command": ( + f"SET SESSION max_execution_time = {self.query_timeout * 1000}" + ) + } + self.tidb_client = TiDBClient.connect( host=self.host, port=self.port, username=username or self.username, password=password or self.password, database=db_name or self.database, + **connect_kwargs, ) def show_tables(self) -> list[str]: @@ -98,7 +117,7 @@ def execute(self, sql_stmts: str | list[str]) -> list[dict]: @property def is_tidb_serverless(self) -> bool: - return TIDB_SERVERLESS_HOST_PATTERN.match(self.host) + return bool(self.host) and bool(TIDB_SERVERLESS_HOST_PATTERN.match(self.host)) def current_username(self) -> str: current_user = self.tidb_client.query("SELECT CURRENT_USER()").scalar() or "" @@ -161,6 +180,7 @@ async def app_lifespan(app: FastMCP) -> AsyncIterator[AppContext]: username=os.getenv("TIDB_USERNAME", "root"), password=os.getenv("TIDB_PASSWORD", ""), database=os.getenv("TIDB_DATABASE", "test"), + query_timeout=MCP_QUERY_TIMEOUT, ) log.info(f"Connected to TiDB: {tidb.host}:{tidb.port}/{tidb.database}") yield AppContext(tidb=tidb) @@ -250,8 +270,12 @@ def create_mcp_server( host: str = "127.0.0.1", port: int = 8000, stateless_http: bool = True, + query_timeout: Optional[int] = None, ) -> FastMCP: """Create and configure the TiDB MCP server.""" + global MCP_QUERY_TIMEOUT + MCP_QUERY_TIMEOUT = query_timeout + mcp = FastMCP( "tidb", instructions="""You are a tidb database expert, you can help me query, create, and execute sql diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..e3fa0ec --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,160 @@ +from unittest.mock import Mock + +import pytest + +pytest.importorskip("mcp.server.fastmcp") + +import pytidb.ext.mcp as mcp_cli +import pytidb.ext.mcp.server as mcp_server + + +def test_mcp_cli_option_forwards_query_timeout(monkeypatch): + captured = {} + + class DummyServer: + def run(self, transport): + captured["transport"] = transport + + def fake_create_mcp_server(**kwargs): + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server) + + mcp_cli.main.main( + args=["--transport", "streamable-http", "--query-timeout", "12"], + prog_name="tidb-mcp-server", + standalone_mode=False, + ) + + assert captured == { + "host": "127.0.0.1", + "port": 8000, + "stateless_http": True, + "query_timeout": 12, + "transport": "streamable-http", + } + + +def test_mcp_cli_env_var_forwards_query_timeout(monkeypatch): + captured = {} + + class DummyServer: + def run(self, transport): + captured["transport"] = transport + + def fake_create_mcp_server(**kwargs): + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server) + monkeypatch.setenv("TIDB_MCP_QUERY_TIMEOUT", "15") + + mcp_cli.main.main( + args=[], + prog_name="tidb-mcp-server", + standalone_mode=False, + ) + + assert captured == { + "host": "127.0.0.1", + "port": 8000, + "stateless_http": False, + "query_timeout": 15, + "transport": "stdio", + } + + +def test_tidb_connector_sets_query_timeout_init_command(monkeypatch): + calls = [] + + def fake_connect(**kwargs): + calls.append(kwargs) + return Mock() + + monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect)) + + mcp_server.TiDBConnector( + host="127.0.0.1", + port=4000, + username="root", + password="", + database="test", + query_timeout=7, + ) + + assert calls == [ + { + "database_url": None, + "host": "127.0.0.1", + "port": 4000, + "username": "root", + "password": "", + "database": "test", + "connect_args": { + "init_command": "SET SESSION max_execution_time = 7000" + }, + } + ] + + +def test_tidb_connector_preserves_query_timeout_when_switching_databases(monkeypatch): + calls = [] + + def fake_connect(**kwargs): + client = Mock() + client.disconnect = Mock() + calls.append(kwargs) + return client + + monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect)) + + connector = mcp_server.TiDBConnector( + host="127.0.0.1", + port=4000, + username="root", + password="", + database="test", + query_timeout=9, + ) + + connector.switch_database("analytics") + + assert calls[1] == { + "host": "127.0.0.1", + "port": 4000, + "username": "root", + "password": "", + "database": "analytics", + "connect_args": { + "init_command": "SET SESSION max_execution_time = 9000" + }, + } + + +def test_app_lifespan_passes_query_timeout(monkeypatch): + captured = {} + + class FakeConnector: + def __init__(self, **kwargs): + captured.update(kwargs) + self.host = kwargs["host"] + self.port = kwargs["port"] + self.database = kwargs["database"] + + def disconnect(self): + captured["disconnected"] = True + + monkeypatch.setattr(mcp_server, "TiDBConnector", FakeConnector) + monkeypatch.setattr(mcp_server, "MCP_QUERY_TIMEOUT", 15) + + async def run_lifespan(): + async with mcp_server.app_lifespan(Mock()): + pass + + import asyncio + + asyncio.run(run_lifespan()) + + assert captured["query_timeout"] == 15 + assert captured["disconnected"] is True