Skip to content

Commit cef3a87

Browse files
authored
[Tools] Add two new tools - list_runs and run_function (#3)
This PR adds two new tools: - `run_function`: allows users to trigger a function through the MCP - `list_runs`: provides a detailed summary of previous runs
1 parent a96e5b2 commit cef3a87

File tree

2 files changed

+301
-1
lines changed

2 files changed

+301
-1
lines changed

src/mcp_server/mlrun_client/api_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import mlrun.db.auth_utils
3535
import mlrun.db.httpdb
36+
import mlrun_pipelines.models
3637

3738
if typing.TYPE_CHECKING:
3839
import mlrun.artifacts
@@ -118,6 +119,23 @@ async def get_artifact(self, project_name: str, *args, **kwargs) -> mlrun.artifa
118119
project = self.__get_project(project_name)
119120
return await asyncio.to_thread(project.get_artifact, *args, **kwargs)
120121

122+
123+
async def run_function(self, project_name: str, *args, **kwargs) -> typing.Union[mlrun.model.RunObject, mlrun_pipelines.models.PipelineNodeWrapper]:
124+
"""Runs a function in the given project
125+
126+
Returns the function runtime after execution.
127+
"""
128+
project = self.__get_project(project_name)
129+
return await asyncio.to_thread(project.run_function, *args, **kwargs)
130+
131+
async def list_runs(self, project_name: str, *args, **kwargs) -> list:
132+
"""Lists runs of a given project
133+
134+
Returns a list of runs.
135+
"""
136+
project = self.__get_project(project_name)
137+
return await asyncio.to_thread(project.list_runs, *args, **kwargs)
138+
121139
async def list_model_endpoints(self, project_name: str, *args, **kwargs) -> list[dict]:
122140
"""Lists model endpoints of a given project
123141

src/mcp_server/server.py

Lines changed: 283 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
from __future__ import annotations
2323

24+
import datetime
2425
from typing import Annotated
2526

26-
import datetime
2727
import fastmcp
2828
from fastmcp.server.dependencies import AccessToken, get_access_token
2929
from pydantic import Field
@@ -101,6 +101,288 @@ async def get_artifact_uri(
101101
async with get_mlrun_api_client() as client:
102102
return (await client.get_artifact(project_name, key=key, tag=tag)).uri
103103

104+
@mcp.tool
105+
async def run_function(
106+
function_name: Annotated[str, Field(description="Name of the function to run (e.g., 'my_function')")],
107+
project_name: project_name_field,
108+
handler: Annotated[
109+
str | None,
110+
Field(
111+
description="Name of the function handler",
112+
default=None,
113+
),
114+
] = None,
115+
name: Annotated[
116+
str,
117+
Field(
118+
description="Execution name",
119+
default="",
120+
),
121+
] = "",
122+
params: Annotated[
123+
dict | None,
124+
Field(
125+
description="Input parameters (dict) to pass to the function (e.g., {'param1': 'value1'})",
126+
default=None,
127+
),
128+
] = None,
129+
inputs: Annotated[
130+
dict | None,
131+
Field(
132+
description="Input objects to pass to the handler (e.g., {'dataset': 'store://datasets/my_dataset'})",
133+
default=None,
134+
),
135+
] = None,
136+
labels: Annotated[
137+
dict | None,
138+
Field(
139+
description="Labels to tag the job/run with (e.g., {'key1': 'val1', 'key2': 'val2'})",
140+
default=None,
141+
),
142+
] = None,
143+
workdir: Annotated[
144+
str,
145+
Field(
146+
description="Working directory of the executed job and the default path for artifact inputs",
147+
default="",
148+
),
149+
] = "",
150+
schedule: Annotated[
151+
str | None,
152+
Field(
153+
description="Standard crontab expression string for scheduling (e.g., '0 0 * * *' for daily at midnight)",
154+
default=None,
155+
),
156+
] = None,
157+
output_path: Annotated[
158+
str | None,
159+
Field(
160+
description="Path to store artifacts, when running in a workflow this will be set automatically",
161+
default=None,
162+
),
163+
] = None,
164+
local: Annotated[
165+
bool | None,
166+
Field(
167+
description="Run the function locally vs on the runtime/cluster",
168+
default=None,
169+
),
170+
] = None,
171+
watch: Annotated[
172+
bool,
173+
Field(
174+
description="Watch/follow run log",
175+
default=True,
176+
),
177+
] = True,
178+
) -> dict:
179+
"""
180+
Run a function in the specified project.
181+
182+
This tool executes an MLRun function and returns information about the run.
183+
The function can be run locally or on a remote runtime/cluster.
184+
"""
185+
async with get_mlrun_api_client() as client:
186+
run_result = await client.run_function(
187+
project_name,
188+
function=function_name,
189+
handler=handler,
190+
name=name,
191+
params=params,
192+
inputs=inputs,
193+
labels=labels,
194+
workdir=workdir,
195+
schedule=schedule,
196+
output_path=output_path,
197+
local=local,
198+
watch=watch,
199+
)
200+
201+
# Extract relevant information from the run result
202+
return {
203+
"uid": run_result.metadata.uid if hasattr(run_result, 'metadata') else None,
204+
"name": run_result.metadata.name if hasattr(run_result, 'metadata') else None,
205+
"project": run_result.metadata.project if hasattr(run_result, 'metadata') else None,
206+
"state": run_result.status.state if hasattr(run_result, 'status') else None,
207+
"results": run_result.status.results if hasattr(run_result, 'status') else None,
208+
"artifacts": [
209+
{
210+
"key": artifact.get("key"),
211+
"kind": artifact.get("kind"),
212+
}
213+
for artifact in (run_result.status.artifacts if hasattr(run_result, 'status') and run_result.status.artifacts else [])
214+
],
215+
}
216+
217+
218+
@mcp.tool
219+
async def list_runs(
220+
project_name: project_name_field,
221+
name: Annotated[
222+
str | None,
223+
Field(
224+
description="Name of the run to retrieve",
225+
default=None,
226+
),
227+
] = None,
228+
uid: Annotated[
229+
str | None,
230+
Field(
231+
description="Unique ID of the run (single UID as string)",
232+
default=None,
233+
),
234+
] = None,
235+
labels: Annotated[
236+
str | None,
237+
Field(
238+
description="Filter runs by labels. Format: 'key1=value1,key2=value2' or 'key1,key2' for key existence",
239+
default=None,
240+
),
241+
] = None,
242+
states: Annotated[
243+
str | None,
244+
Field(
245+
description="Comma-separated list of states to filter by (e.g., 'completed,running')",
246+
default=None,
247+
),
248+
] = None,
249+
sort: Annotated[
250+
bool,
251+
Field(
252+
description="Whether to sort the result according to their start time",
253+
default=True,
254+
),
255+
] = True,
256+
iter: Annotated[
257+
bool,
258+
Field(
259+
description="If True return runs from all iterations. Otherwise, return only runs whose iter is 0",
260+
default=False,
261+
),
262+
] = False,
263+
start_time_from: Annotated[
264+
str | None,
265+
Field(
266+
description="Filter by run start time from (ISO format with timezone, e.g., '2025-11-24T00:00:00Z')",
267+
default=None,
268+
),
269+
] = None,
270+
start_time_to: Annotated[
271+
str | None,
272+
Field(
273+
description="Filter by run start time to (ISO format with timezone, e.g., '2025-11-27T23:59:59Z')",
274+
default=None,
275+
),
276+
] = None,
277+
last_update_time_from: Annotated[
278+
str | None,
279+
Field(
280+
description="Filter by run last update time from (ISO format with timezone)",
281+
default=None,
282+
),
283+
] = None,
284+
last_update_time_to: Annotated[
285+
str | None,
286+
Field(
287+
description="Filter by run last update time to (ISO format with timezone)",
288+
default=None,
289+
),
290+
] = None,
291+
end_time_from: Annotated[
292+
str | None,
293+
Field(
294+
description="Filter by run end time from (ISO format with timezone)",
295+
default=None,
296+
),
297+
] = None,
298+
end_time_to: Annotated[
299+
str | None,
300+
Field(
301+
description="Filter by run end time to (ISO format with timezone)",
302+
default=None,
303+
),
304+
] = None,
305+
) -> list[dict]:
306+
"""
307+
Retrieve a list of runs.
308+
309+
The default returns the runs from the last week, partitioned by name.
310+
To override the default, specify any filter.
311+
"""
312+
# Parse labels parameter
313+
labels_param = None
314+
if labels:
315+
# Check if it's key=value format or just keys
316+
if '=' in labels:
317+
# Parse as dict: "key1=value1,key2=value2"
318+
labels_param = {}
319+
for label in labels.split(','):
320+
if '=' in label:
321+
key, value = label.split('=', 1)
322+
labels_param[key.strip()] = value.strip()
323+
else:
324+
# Key existence check
325+
labels_param[label.strip()] = None
326+
else:
327+
# Parse as list of keys: "key1,key2"
328+
labels_param = [key.strip() for key in labels.split(',')]
329+
330+
# Parse states parameter
331+
states_list = None
332+
if states:
333+
states_list = [state.strip() for state in states.split(',')]
334+
335+
# Parse datetime parameters
336+
def parse_datetime(dt_str: str | None) -> datetime.datetime | None:
337+
if not dt_str:
338+
return None
339+
try:
340+
return datetime.datetime.fromisoformat(dt_str.replace('Z', '+00:00'))
341+
except ValueError:
342+
raise ValueError(
343+
f"Invalid datetime format: '{dt_str}'. "
344+
"Expected ISO format with timezone (e.g., '2025-11-24T00:00:00Z')"
345+
)
346+
347+
start_time_from_dt = parse_datetime(start_time_from)
348+
start_time_to_dt = parse_datetime(start_time_to)
349+
last_update_time_from_dt = parse_datetime(last_update_time_from)
350+
last_update_time_to_dt = parse_datetime(last_update_time_to)
351+
end_time_from_dt = parse_datetime(end_time_from)
352+
end_time_to_dt = parse_datetime(end_time_to)
353+
354+
async with get_mlrun_api_client() as client:
355+
runs = await client.list_runs(
356+
project_name,
357+
name=name,
358+
uid=uid,
359+
labels=labels_param,
360+
states=states_list,
361+
sort=sort,
362+
iter=iter,
363+
start_time_from=start_time_from_dt,
364+
start_time_to=start_time_to_dt,
365+
last_update_time_from=last_update_time_from_dt,
366+
last_update_time_to=last_update_time_to_dt,
367+
end_time_from=end_time_from_dt,
368+
end_time_to=end_time_to_dt,
369+
)
370+
371+
# Convert RunList to list of dicts
372+
return [
373+
{
374+
"uid": run.get("metadata", {}).get("uid"),
375+
"name": run.get("metadata", {}).get("name"),
376+
"project": run.get("metadata", {}).get("project"),
377+
"state": run.get("status", {}).get("state"),
378+
"start_time": run.get("status", {}).get("start_time"),
379+
"last_update": run.get("status", {}).get("last_update"),
380+
"labels": run.get("metadata", {}).get("labels", {}),
381+
"results": run.get("status", {}).get("results"),
382+
}
383+
for run in runs
384+
]
385+
104386

105387
@mcp.tool
106388
async def get_model_endpoints(

0 commit comments

Comments
 (0)