Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,5 @@ with tidb_client.session() as session:
> Click the button below to install **TiDB MCP Server** in Cursor. Then, confirm by clicking **Install** when prompted.
>
> [![Install TiDB MCP Server](https://cursor.com/deeplink/mcp-install-dark.svg)](https://cursor.com/install-mcp?name=TiDB&config=eyJjb21tYW5kIjoidXZ4IC0tZnJvbSBweXRpZGJbbWNwXSB0aWRiLW1jcC1zZXJ2ZXIiLCJlbnYiOnsiVElEQl9IT1NUIjoibG9jYWxob3N0IiwiVElEQl9QT1JUIjoiNDAwMCIsIlRJREJfVVNFUk5BTUUiOiJyb290IiwiVElEQl9QQVNTV09SRCI6IiIsIlRJREJfREFUQUJBU0UiOiJ0ZXN0In19)
>
> To limit long-running MCP queries, set `TIDB_MCP_QUERY_TIMEOUT` to the maximum execution time in seconds.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is recommended that we add a subsection about advanced settings in the MCP Server section of the official documentation.

https://github.com/pingcap/docs/edit/master/ai/integrations/tidb-mcp-server.md

9 changes: 9 additions & 0 deletions pytidb/ext/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@
show_default=True,
help="Port to bind for network transports",
)
@click.option(
"--query-timeout",
type=click.IntRange(min=1),
envvar="TIDB_MCP_QUERY_TIMEOUT",
default=None,
help="Maximum execution time for TiDB queries in seconds",
)
def main(
transport: Literal["stdio", "sse", "streamable-http"] = "stdio",
host: str = "127.0.0.1",
port: int = 8000,
query_timeout: int | None = None,
):
logging.basicConfig(
level=logging.INFO,
Expand All @@ -47,5 +55,6 @@ def main(
host=host,
port=port,
stateless_http=stateless,
query_timeout=query_timeout,
)
mcp.run(transport=transport)
30 changes: 27 additions & 3 deletions pytidb/ext/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# Constants
TIDB_SERVERLESS_USERNAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+$")
MCP_QUERY_TIMEOUT: Optional[int] = None


# TiDB Connector
Expand All @@ -38,22 +39,31 @@ def __init__(
username: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
query_timeout: Optional[int] = None,
):
self.query_timeout = query_timeout
connect_kwargs = {}
if query_timeout is not None:
connect_kwargs["connect_args"] = {
"init_command": f"SET SESSION max_execution_time = {query_timeout * 1000}"
}

self.tidb_client = TiDBClient.connect(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Wrong kwarg to TiDBClient.connect breaks MCP server startup

Why: database_url is not a parameter of TiDBClient.connect. The method signature expects url, not database_url. This kwarg flows into **kwargs and is forwarded to SQLAlchemy create_engine, which raises TypeError for unexpected keyword arguments, preventing any MCP server DB connection.

Evidence: pytidb/ext/mcp/server.py:51 passes TiDBClient.connect(database_url=database_url, ...) but pytidb/client.py:49 defines def connect(..., url: Optional[str] = None, ..., **kwargs) and pytidb/client.py:88 calls create_engine(url, echo=debug, **kwargs). The correct parameter name is url, not database_url.

url=database_url,
database_url=database_url,
host=host,
port=port,
username=username,
password=password,
database=database,
**connect_kwargs,
)
if database_url:
uri = MySQLDsn(database_url)
self.host = uri.host
self.port = uri.port
self.username = uri.username
self.password = uri.password
self.database = uri.path.lstrip("/")
self.database = (uri.path or "").lstrip("/") or None
else:
self.host = host
self.port = port
Expand All @@ -70,12 +80,21 @@ def switch_database(
username: Optional[str] = None,
password: Optional[str] = None,
) -> None:
connect_kwargs = {}
if self.query_timeout is not None:
connect_kwargs["connect_args"] = {
"init_command": (
f"SET SESSION max_execution_time = {self.query_timeout * 1000}"
)
}

self.tidb_client = TiDBClient.connect(
host=self.host,
port=self.port,
username=username or self.username,
password=password or self.password,
database=db_name or self.database,
**connect_kwargs,
)

def show_tables(self) -> list[str]:
Expand All @@ -98,7 +117,7 @@ def execute(self, sql_stmts: str | list[str]) -> list[dict]:

@property
def is_tidb_serverless(self) -> bool:
return TIDB_SERVERLESS_HOST_PATTERN.match(self.host)
return bool(self.host) and bool(TIDB_SERVERLESS_HOST_PATTERN.match(self.host))

def current_username(self) -> str:
current_user = self.tidb_client.query("SELECT CURRENT_USER()").scalar() or ""
Expand Down Expand Up @@ -161,6 +180,7 @@ async def app_lifespan(app: FastMCP) -> AsyncIterator[AppContext]:
username=os.getenv("TIDB_USERNAME", "root"),
password=os.getenv("TIDB_PASSWORD", ""),
database=os.getenv("TIDB_DATABASE", "test"),
query_timeout=MCP_QUERY_TIMEOUT,
)
log.info(f"Connected to TiDB: {tidb.host}:{tidb.port}/{tidb.database}")
yield AppContext(tidb=tidb)
Expand Down Expand Up @@ -250,8 +270,12 @@ def create_mcp_server(
host: str = "127.0.0.1",
port: int = 8000,
stateless_http: bool = True,
query_timeout: Optional[int] = None,
) -> FastMCP:
"""Create and configure the TiDB MCP server."""
global MCP_QUERY_TIMEOUT
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] Query timeout stored in module-global state (cross-instance leakage)

Why: Multiple create_mcp_server() calls can overwrite MCP_QUERY_TIMEOUT, making behavior order-dependent and potentially applying the wrong timeout to a running server. When server A is created with timeout X, then server B with timeout Y, and then A is started, it will read timeout Y instead of X.

Evidence: pytidb/ext/mcp/server.py:276 uses global MCP_QUERY_TIMEOUT; MCP_QUERY_TIMEOUT = query_timeout and pytidb/ext/mcp/server.py:183 reads query_timeout=MCP_QUERY_TIMEOUT. This creates cross-instance state leakage in multi-server scenarios.

MCP_QUERY_TIMEOUT = query_timeout

mcp = FastMCP(
"tidb",
instructions="""You are a tidb database expert, you can help me query, create, and execute sql
Expand Down
160 changes: 160 additions & 0 deletions tests/test_mcp_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from unittest.mock import Mock

import pytest

pytest.importorskip("mcp.server.fastmcp")

import pytidb.ext.mcp as mcp_cli
import pytidb.ext.mcp.server as mcp_server


def test_mcp_cli_option_forwards_query_timeout(monkeypatch):
captured = {}

class DummyServer:
def run(self, transport):
captured["transport"] = transport

def fake_create_mcp_server(**kwargs):
captured.update(kwargs)
return DummyServer()

monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server)

mcp_cli.main.main(
args=["--transport", "streamable-http", "--query-timeout", "12"],
prog_name="tidb-mcp-server",
standalone_mode=False,
)

assert captured == {
"host": "127.0.0.1",
"port": 8000,
"stateless_http": True,
"query_timeout": 12,
"transport": "streamable-http",
}


def test_mcp_cli_env_var_forwards_query_timeout(monkeypatch):
captured = {}

class DummyServer:
def run(self, transport):
captured["transport"] = transport

def fake_create_mcp_server(**kwargs):
captured.update(kwargs)
return DummyServer()

monkeypatch.setattr(mcp_cli, "create_mcp_server", fake_create_mcp_server)
monkeypatch.setenv("TIDB_MCP_QUERY_TIMEOUT", "15")

mcp_cli.main.main(
args=[],
prog_name="tidb-mcp-server",
standalone_mode=False,
)

assert captured == {
"host": "127.0.0.1",
"port": 8000,
"stateless_http": False,
"query_timeout": 15,
"transport": "stdio",
}


def test_tidb_connector_sets_query_timeout_init_command(monkeypatch):
calls = []

def fake_connect(**kwargs):
calls.append(kwargs)
return Mock()

monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] New tests encode the wrong connect API and mask the production failure

Why: Tests monkeypatch TiDBClient.connect as a staticmethod and assert a "database_url" kwarg (line 88), so CI can pass while real TiDBClient.connect crashes. The test encodes the wrong API contract and will fail once the production bug is fixed.

Evidence: tests/test_mcp_server.py:75 uses monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect)) and line 88 asserts "database_url": None. This test validates the broken behavior instead of the correct API.


mcp_server.TiDBConnector(
host="127.0.0.1",
port=4000,
username="root",
password="",
database="test",
query_timeout=7,
)

assert calls == [
{
"database_url": None,
"host": "127.0.0.1",
"port": 4000,
"username": "root",
"password": "",
"database": "test",
"connect_args": {
"init_command": "SET SESSION max_execution_time = 7000"
},
}
]


def test_tidb_connector_preserves_query_timeout_when_switching_databases(monkeypatch):
calls = []

def fake_connect(**kwargs):
client = Mock()
client.disconnect = Mock()
calls.append(kwargs)
return client

monkeypatch.setattr(mcp_server.TiDBClient, "connect", staticmethod(fake_connect))

connector = mcp_server.TiDBConnector(
host="127.0.0.1",
port=4000,
username="root",
password="",
database="test",
query_timeout=9,
)

connector.switch_database("analytics")

assert calls[1] == {
"host": "127.0.0.1",
"port": 4000,
"username": "root",
"password": "",
"database": "analytics",
"connect_args": {
"init_command": "SET SESSION max_execution_time = 9000"
},
}


def test_app_lifespan_passes_query_timeout(monkeypatch):
captured = {}

class FakeConnector:
def __init__(self, **kwargs):
captured.update(kwargs)
self.host = kwargs["host"]
self.port = kwargs["port"]
self.database = kwargs["database"]

def disconnect(self):
captured["disconnected"] = True

monkeypatch.setattr(mcp_server, "TiDBConnector", FakeConnector)
monkeypatch.setattr(mcp_server, "MCP_QUERY_TIMEOUT", 15)

async def run_lifespan():
async with mcp_server.app_lifespan(Mock()):
pass

import asyncio

asyncio.run(run_lifespan())

assert captured["query_timeout"] == 15
assert captured["disconnected"] is True