Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"openai>=1.65.3",
"pandas>=2.2.3",
"pandas-stubs==2.3.0.250703",
"prometheus-client>=0.22.1",
"prompt-toolkit>=3.0.50",
"pydantic>=2.10.6",
"pydantic-settings>=2.8.1",
Expand Down
15 changes: 15 additions & 0 deletions src/dremioai/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (C) 2017-2025 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
23 changes: 23 additions & 0 deletions src/dremioai/metrics/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#
# Copyright (C) 2017-2025 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from prometheus_client import CollectorRegistry, make_asgi_app

_registry = CollectorRegistry()


def get_metrics_app():
global _registry
return make_asgi_app(_registry)
30 changes: 30 additions & 0 deletions src/dremioai/metrics/tool_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# Copyright (C) 2017-2025 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from prometheus_client import Counter, Histogram
from dremioai.metrics.registry import _registry

invocation_counter = Counter(
"mcp_tool_invocations",
"Number of times a tool is invoked",
["tool", "project_id"],
registry=_registry,
)
invocation_duration = Histogram(
"mcp_tool_invocation_duration",
"Time taken to invoke a tool",
["tool", "project_id"],
registry=_registry,
)
4 changes: 4 additions & 0 deletions src/dremioai/servers/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from mcp.shared.auth import OAuthMetadata
from pydantic import AnyHttpUrl
from pydantic.networks import AnyUrl

from dremioai.metrics.registry import get_metrics_app
from starlette.requests import Request
from starlette.responses import Response

Expand Down Expand Up @@ -117,6 +119,8 @@ def streamable_http_app(self):
# like ../mcp/{project_id}/.. and extract that project id as
# context var
app.add_middleware(ProjectIdMiddleware)

app.mount("/metrics", get_metrics_app(), name="metrics")
return app

def __init__(self, *args, **kwargs):
Expand Down
26 changes: 24 additions & 2 deletions src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from contextvars import ContextVar
from typing import (
List,
Expand All @@ -34,7 +33,6 @@

from dataclasses import dataclass, asdict, field

from starlette.datastructures import URL
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request

Expand All @@ -55,6 +53,7 @@
from sqlglot import expressions
from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.provider import AccessToken
from dremioai.metrics.tool_metrics import invocation_counter, invocation_duration

logger = log.logger(__name__)

Expand Down Expand Up @@ -158,6 +157,23 @@ async def _impl(self, *args: P.args, **kw: P.kwargs) -> T:
return _impl


def with_metrics(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@functools.wraps(fn)
async def _impl(self, *args: P.args, **kw: P.kwargs) -> T:
project_id = None
if dremio := settings.instance().dremio:
project_id = dremio.project_id
invocation_counter.labels(
project_id=project_id, tool=self.__class__.__name__
).inc()
with invocation_duration.labels(
project_id=project_id, tool=self.__class__.__name__
).time():
return await fn(self, *args, **kw)

return _impl


def _get_class_var_hints(tool: Tools, name: str) -> bool:
if class_var := get_type_hints(tool, include_extras=True).get(name):
if cls_args := get_args(class_var):
Expand Down Expand Up @@ -193,6 +209,7 @@ def group_by(self, df, by):
return df.groupby(by).size().reset_index(name="count").to_dict(orient="records")

@secured
@with_metrics
async def invoke(self) -> Dict[str, Any]:
"""Get the stats and details of failed or canceled jobs executed in the Dremio cluster in the past 7 days
along with a split by job type
Expand Down Expand Up @@ -303,6 +320,7 @@ def ensure_query_allowed(s: str):
)

@secured
@with_metrics
async def invoke(self, s: str) -> Dict[str, Union[List[Dict[Any, Any]] | str]]:
"""Run a SELECT sql query on the Dremio cluster and return the results.
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
Expand All @@ -328,6 +346,7 @@ class BuildUsageReport(Tools):
project_id_required: ClassVar[Annotated[bool, True]]

@secured
@with_metrics
async def invoke(
self, by: Optional[Literal["PROJECT", "ENGINE"]] = "ENGINE"
) -> Dict[str, Any]:
Expand Down Expand Up @@ -391,6 +410,7 @@ class GetSchemaOfTable(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]

@secured
@with_metrics
async def invoke(self, table_name: Union[str | List[str]]) -> Dict[str, Any]:
"""Gets the schema of the given table.

Expand All @@ -416,6 +436,7 @@ class GetTableOrViewLineage(Tools):
For: ClassVar[Annotated[ToolType, ToolType.FOR_SELF | ToolType.FOR_DATA_PATTERNS]]

@secured
@with_metrics
async def invoke(self, table_name: Union[str, List[str]]) -> Dict[str, Any]:
"""Finds the lineage of a table or view in the Dremio cluster

Expand All @@ -437,6 +458,7 @@ class SearchTableAndViews(Tools):
]

@secured
@with_metrics
async def invoke(self, query: str) -> Dict[str, Any]:
"""Runs a semantic search on the Dremio cluster to find tables and views that match the query.

Expand Down
25 changes: 25 additions & 0 deletions tests/e2e/test_e2e_pat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
#
import uuid

import httpx
import pytest
from mcp.types import CallToolResult
from conftest import http_streamable_client_server, http_streamable_mcp_server
from urllib.parse import urlparse
from dremioai.config import settings

from dremioai.metrics.tool_metrics import invocation_counter


@pytest.mark.asyncio
Expand Down Expand Up @@ -54,3 +59,23 @@ async def test_tool_pat(mock_config_dir, logging_server, logging_level, project_
assert (
str(project_id) in le.path
), f"{le} does not have the right project id"

async with httpx.AsyncClient() as client:
u = urlparse(sf.mcp_server.url)._replace(path="/metrics/").geturl()
r = await client.get(u, headers={"Authorization": "Bearer my-token"})
assert (
r.status_code == 200
), f"Error getting metrics: {r.text} status={r.status_code}"
if project_id is None:
project_id = settings.instance().dremio.project_id
for line in r.text.splitlines():
if (
line.startswith(f"{invocation_counter._name}_total{{")
and f'project_id="{project_id}"' in line
):
assert (
float(line.split()[-1]) == 1.0
), f"Invocation count not 1: {line}"
break
else:
assert False, f"Invocation count for {project_id} not found in {r.text}"
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.