-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy path__main__.py
More file actions
213 lines (169 loc) · 7.3 KB
/
__main__.py
File metadata and controls
213 lines (169 loc) · 7.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import asyncio
import logging
import os
from typing import Any, Optional

import click
from datahub.ingestion.graph.config import ClientMode, DatahubClientConfig
from datahub.sdk.main_client import DataHubClient
from datahub.telemetry import telemetry
from fastmcp import FastMCP
from fastmcp.server.auth import TokenVerifier
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from typing_extensions import Literal

from mcp_server_datahub._telemetry import TelemetryMiddleware
from mcp_server_datahub._version import __version__
from mcp_server_datahub.document_tools_middleware import DocumentToolsMiddleware
from mcp_server_datahub.mcp_server import mcp, register_all_tools, with_datahub_client
from mcp_server_datahub.version_requirements import VersionFilterMiddleware
# Configure the root logger at INFO level. This runs at import time, before
# any tools are registered.
logging.basicConfig(level=logging.INFO)

# Register tools with OSS-compatible descriptions
# NOTE(review): this is an import-time side effect on the shared ``mcp``
# instance; importing this module is enough to register the tools.
register_all_tools(is_oss=True)

# GraphQL query used by ``_verify_client`` to check that a token can
# authenticate: it requests the current user's URN and username.
_GET_ME_QUERY = "query getMe { me { corpUser { urn username } } }"
def _build_client(server_url: str, token: str) -> DataHubClient:
    """Create a DataHub SDK client for ``server_url`` that uses ``token``.

    The client is configured in SDK mode and tagged with the component name
    ``mcp-server-datahub/<version>``.
    """
    component_name = f"mcp-server-datahub/{__version__}"
    client_config = DatahubClientConfig(
        server=server_url,
        token=token,
        client_mode=ClientMode.SDK,
        datahub_component=component_name,
    )
    return DataHubClient(config=client_config)
def _verify_client(client: DataHubClient) -> None:
    """Check that ``client`` can authenticate against DataHub.

    Sends the ``getMe`` GraphQL query and discards the result. Any error
    raised by the call propagates to the caller unchanged.
    """
    graph = client._graph
    graph.execute_graphql(_GET_ME_QUERY)
def _token_from_request() -> Optional[str]:
    """Extract a DataHub token from the current HTTP request.

    Reads the ``Authorization: Bearer <token>`` header. The ``Bearer`` scheme
    is matched case-insensitively, because HTTP authentication schemes are
    case-insensitive (RFC 7235 section 2.1). A header such as
    ``bearer <token>`` is therefore recognised too.

    Returns:
        The token text after the ``Bearer `` prefix, or ``None`` in any of
        these cases: there is no active HTTP request (e.g. stdio transport),
        the header is missing, the header uses another scheme, or the token
        part is empty.
    """
    try:
        request = get_http_request()
    except RuntimeError:
        # No HTTP request in the current context (e.g. stdio transport).
        return None
    auth = request.headers.get("authorization", "")
    prefix = "Bearer "
    if auth[: len(prefix)].lower() != prefix.lower():
        return None
    token = auth[len(prefix) :]
    # An empty credential is treated as "no token". Callers then fall back
    # to their default instead of building a client with an empty token.
    return token or None
class _DataHubTokenVerifier(TokenVerifier):
    """FastMCP TokenVerifier that validates DataHub bearer tokens.

    Called by FastMCP's BearerAuthBackend for every HTTP request that carries
    an Authorization: Bearer header. If the token is valid a synthetic
    AccessToken is returned; otherwise None causes FastMCP to reply with
    401 WWW-Authenticate: Bearer automatically.

    Validation runs the ``getMe`` GraphQL query against DataHub with the
    presented token, so each verification costs one DataHub round-trip.
    """

    def __init__(self, server_url: str) -> None:
        """Create a verifier that checks tokens against ``server_url``."""
        super().__init__()
        self._server_url = server_url

    def _check_token(self, token: str) -> None:
        """Synchronously authenticate ``token``; raises if the check fails."""
        _verify_client(_build_client(self._server_url, token))

    async def verify_token(self, token: str) -> Optional[AccessToken]:
        """Return an AccessToken if DataHub accepts ``token``, else ``None``."""
        try:
            # The DataHub client is synchronous, so the round-trip runs in a
            # worker thread. Otherwise a slow DataHub response would block
            # the event loop and every other in-flight request with it.
            await asyncio.to_thread(self._check_token, token)
        except Exception:
            # Any failure (rejected token, unreachable server, ...) rejects
            # the request. Log at debug level so failures stay diagnosable
            # without flooding logs when clients send bad tokens.
            logging.getLogger(__name__).debug(
                "DataHub token verification failed", exc_info=True
            )
            return None
        else:
            return AccessToken(
                client_id=f"mcp-server-datahub/{__version__}",
                scopes=[],
                token=token,
            )
class _DataHubClientMiddleware(Middleware):
    """Middleware that propagates the DataHub client ContextVar into each request.

    When running with HTTP transport (stateless_http=True), each request is handled
    in a separate async context that does not inherit ContextVars from the main
    thread. This middleware ensures the DataHub client is available in every request
    context by setting the ContextVar at the start of each MCP message.

    The token for each message is chosen in this order:

    1. a Bearer token from the current HTTP request's ``Authorization`` header;
    2. the default (global) token given at construction time, if any.

    If neither is available, a ``ValueError`` is raised.

    This middleware does not validate tokens itself. When the server runs over
    HTTP without a global token, ``main()`` installs ``_DataHubTokenVerifier``
    as ``mcp.auth``, so Bearer tokens are validated before they reach here. In
    other configurations a caller-supplied token is passed to DataHub as-is.

    Must be added as the first middleware so it wraps all other middlewares.
    """

    def __init__(self, server_url: str, default_token: Optional[str] = None) -> None:
        self._server_url = server_url
        self._default_token = default_token

    def _client_for_request(self) -> DataHubClient:
        """Build the DataHub client to use for the current message.

        Raises:
            ValueError: if there is neither a request Bearer token nor a
                default token.
        """
        token = _token_from_request()
        if token is not None:
            # The caller's own token takes precedence over the default one.
            return _build_client(self._server_url, token)
        if self._default_token is not None:
            return _build_client(self._server_url, self._default_token)
        # Mention both sources: in stdio mode there is no HTTP header, so
        # DATAHUB_GMS_TOKEN is the only way to supply a token.
        raise ValueError(
            "No DataHub token provided. Supply a token via the Authorization "
            "header or set the DATAHUB_GMS_TOKEN environment variable."
        )

    async def on_message(
        self,
        context: Any,
        call_next: Any,
    ) -> Any:
        """Run the rest of the chain inside ``with_datahub_client`` for this message."""
        with with_datahub_client(self._client_for_request()):
            return await call_next(context)
# Health-check route for the MCP server.
# It is only served when the server runs with an HTTP-based transport
# (HTTP/SSE). It is not useful in stdio mode, where the server usually runs
# as a subprocess of the client.
@mcp.custom_route("/health", methods=["GET"])
async def health(request: Request) -> Response:
    """Report liveness with a fixed ``{"status": "ok"}`` JSON body."""
    body = {"status": "ok"}
    return JSONResponse(body)
# Guards create_app() so the middlewares are only added once.
_app_initialized = False


def create_app() -> FastMCP:
    """Create and configure the MCP server with DataHub client and middlewares.

    This is the factory used by ``fastmcp dev`` / ``fastmcp run`` (through
    the ``__main__.py:create_app`` entrypoint). The CLI ``main()`` entrypoint
    also calls it.

    Repeated calls are safe: after the first successful call, the same
    ``mcp`` instance is returned and no middleware is added again.

    Raises:
        RuntimeError: if ``DATAHUB_GMS_URL`` is unset or empty.
    """
    global _app_initialized
    if not _app_initialized:
        gms_url = os.environ.get("DATAHUB_GMS_URL")
        if not gms_url:
            raise RuntimeError("DATAHUB_GMS_URL environment variable is required.")

        default_token = os.environ.get("DATAHUB_GMS_TOKEN")
        if default_token:
            # Fail fast at startup if the configured global token is rejected.
            _verify_client(_build_client(gms_url, default_token))

        # The client middleware goes first so the DataHub client ContextVar
        # is set before every later middleware and tool handler runs. This
        # matters most for HTTP transport, where each request runs in its
        # own async context.
        mcp.add_middleware(_DataHubClientMiddleware(gms_url, default_token))
        mcp.add_middleware(TelemetryMiddleware())
        mcp.add_middleware(VersionFilterMiddleware())
        mcp.add_middleware(DocumentToolsMiddleware())

        _app_initialized = True
    return mcp
@click.command()
@click.version_option(version=__version__)
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse", "http"]),
    default="stdio",
)
@click.option(
    "--debug",
    is_flag=True,
    default=False,
)
@telemetry.with_telemetry(
    capture_kwargs=["transport"],
)
def main(transport: Literal["stdio", "sse", "http"], debug: bool) -> None:
    """CLI entry point: build the app and serve it on the chosen transport.

    With ``--transport http`` the server runs stateless on ``0.0.0.0``. If no
    ``DATAHUB_GMS_TOKEN`` is set, every request must then carry a Bearer
    token, which ``_DataHubTokenVerifier`` checks.
    """
    if debug:
        # Register LoggingMiddleware before create_app() adds its own. FastMCP
        # reverses the middleware list, so this one ends up outermost and
        # logs the full request/response after all other middleware effects.
        mcp.add_middleware(LoggingMiddleware(include_payloads=True))

    create_app()

    if transport != "http":
        mcp.run(transport=transport, show_banner=False)
        return

    if not os.environ.get("DATAHUB_GMS_TOKEN"):
        # Without a global token, every HTTP caller must bring its own token.
        mcp.auth = _DataHubTokenVerifier(os.environ.get("DATAHUB_GMS_URL", ""))
    mcp.run(
        transport=transport,
        show_banner=False,
        stateless_http=True,
        host="0.0.0.0",
    )


if __name__ == "__main__":
    main()