Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions optuna_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import optuna
import optuna_dashboard
import plotly
from pydantic import BaseModel
from pydantic import Field


class OptunaMCP(FastMCP):
Expand Down Expand Up @@ -59,12 +61,38 @@ class TrialToAdd:
system_attrs: dict[str, typing.Any] | None


class StudyInfo(BaseModel):
study_name: str
sampler_name: (
typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] | None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notes] A nit, but supported samplers are listed in the definition of set_sampler as well. We may define SamplerName=typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] to prevent the inconsistency.

But we can work on it in a follow-up PR.

) = Field(default=None, description="The name of the sampler used in the study, if available.")
directions: list[typing.Literal["minimize", "maximize"]] | None = Field(
default=None, description="The optimization directions for each objective, if available."
)


class TrialInfo(BaseModel):
trial_number: int
params: dict[str, typing.Any] | None = Field(
default=None, description="The parameter values suggested by the trial."
)
values: list[float] | None = Field(
default=None, description="The objective values of the trial, if available."
)
user_attrs: dict[str, typing.Any] | None = Field(
default=None, description="User-defined attributes for the trial, if any."
)
system_attrs: dict[str, typing.Any] | None = Field(
default=None, description="System-defined attributes for the trial, if any."
)


def register_tools(mcp: OptunaMCP) -> OptunaMCP:
@mcp.tool()
def create_study(
study_name: str,
directions: list[typing.Literal["minimize", "maximize"]] | None = None,
) -> str:
) -> StudyInfo:
"""Create a new Optuna study with the given study_name and directions.

If the study already exists, it will be simply loaded.
Expand All @@ -79,10 +107,10 @@ def create_study(
if mcp.storage is None:
mcp.storage = mcp.study._storage

return f"Optuna study {study_name} has been prepared"
return StudyInfo(study_name=study_name)

@mcp.tool()
def get_all_study_names() -> str:
def get_all_study_names() -> list[StudyInfo] | str:
"""Get all study names from the storage."""
storage: str | optuna.storages.BaseStorage | None = None
if mcp.study is not None:
Expand All @@ -93,10 +121,10 @@ def get_all_study_names() -> str:
return "No storage specified."

study_names = optuna.get_all_study_names(storage)
return f"All study names: {study_names}"
return [StudyInfo(study_name=name) for name in study_names]

@mcp.tool()
def ask(search_space: dict) -> str:
def ask(search_space: dict) -> TrialInfo | str:
"""Suggest new parameters using Optuna

search_space must be a string that can be evaluated to a dictionary to specify Optuna's distributions.
Expand All @@ -117,10 +145,13 @@ def ask(search_space: dict) -> str:

trial = mcp.study.ask(fixed_distributions=distributions)

return f"Trial {trial.number} suggested: {json.dumps(trial.params)}"
return TrialInfo(
trial_number=trial.number,
params=trial.params,
)

@mcp.tool()
def tell(trial_number: int, values: float | list[float]) -> str:
def tell(trial_number: int, values: float | list[float]) -> TrialInfo:
"""Report the result of a trial"""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
Expand All @@ -131,12 +162,15 @@ def tell(trial_number: int, values: float | list[float]) -> str:
state=optuna.trial.TrialState.COMPLETE,
skip_if_finished=True,
)
return f"Trial {trial_number} reported with values {json.dumps(values)}"
return TrialInfo(
trial_number=trial_number,
values=[values] if isinstance(values, float) else values,
)

@mcp.tool()
def set_sampler(
name: typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"],
) -> str:
) -> StudyInfo:
"""Set the sampler for the study.
The sampler must be one of the following:
- TPESampler
Expand All @@ -152,7 +186,10 @@ def set_sampler(
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
mcp.study.sampler = sampler
return f"Sampler set to {name}"
return StudyInfo(
study_name=mcp.study.study_name,
sampler_name=name,
)

@mcp.tool()
def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str:
Expand All @@ -168,7 +205,7 @@ def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str:
return f"User attribute {key} set to {json.dumps(value)} for trial {trial_number}"

@mcp.tool()
def get_trial_user_attrs(trial_number: int) -> str:
def get_trial_user_attrs(trial_number: int) -> TrialInfo:
"""Get user attributes in a trial"""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
Expand All @@ -177,7 +214,11 @@ def get_trial_user_attrs(trial_number: int) -> str:
mcp.study._study_id, trial_number
)
trial = storage.get_trial(trial_id)
return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}"
# return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}"
return TrialInfo(
trial_number=trial_number,
user_attrs=trial.user_attrs,
)

@mcp.tool()
def set_metric_names(metric_names: list[str]) -> str:
Expand All @@ -201,12 +242,15 @@ def get_metric_names() -> str:
return f"Metric names: {json.dumps(mcp.study.metric_names)}"

@mcp.tool()
def get_directions() -> str:
def get_directions() -> StudyInfo:
"""Get the directions of the study."""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
directions = [d.name.lower() for d in mcp.study.directions]
return f"Directions: {json.dumps(directions)}"
return StudyInfo(
study_name=mcp.study.study_name,
directions=directions,
)

@mcp.tool()
def get_trials() -> str:
Expand All @@ -217,7 +261,7 @@ def get_trials() -> str:
return f"Trials: \n{csv_string}"

@mcp.tool()
def best_trial() -> str:
def best_trial() -> TrialInfo:
"""Get the best trial

This feature can only be used for single-objective optimization. If your study is multi-objective, use best_trials instead.
Expand All @@ -226,14 +270,29 @@ def best_trial() -> str:
raise ValueError("No study has been created. Please create a study first.")

trial = mcp.study.best_trial
return f"Best trial: {trial.number} with params {json.dumps(trial.params)} and value {json.dumps(trial.value)}"
return TrialInfo(
trial_number=trial.number,
params=trial.params,
values=trial.values,
user_attrs=trial.user_attrs,
system_attrs=trial.system_attrs,
)

@mcp.tool()
def best_trials() -> str:
def best_trials() -> list[TrialInfo]:
"""Return trials located at the Pareto front in the study."""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
return f"Best trials: {mcp.study.best_trials}"
return [
TrialInfo(
trial_number=trial.number,
params=trial.params,
values=trial.values,
user_attrs=trial.user_attrs,
system_attrs=trial.system_attrs,
)
for trial in mcp.study.best_trials
]

def _create_trial(trial: TrialToAdd) -> optuna.trial.FrozenTrial:
"""Create a trial from the given parameters."""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"plotly>=6.0.1",
"torch>=2.7.0",
"bottle>=0.13.4",
"fastmcp>=2.12.2",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Question] Do we need to migrate the mcp library from the official mcp library to fastmcp to support structured output? mcp==1.14.0 seems to support the structured output feature.

https://pypi.org/project/mcp/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing it out. I’ve updated the MCP version to mcp[cli]>=1.10.0 and removed fastmcp.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your update. I've confirmed that the mcp package supports the structured output since 1.10.0 in https://pypi.org/project/mcp/1.10.0/ and https://github.com/modelcontextprotocol/python-sdk/releases/tag/v1.10.0. Since the current implementation uses Pydantic, we may need to apply the fix modelcontextprotocol/python-sdk#1099, which was included in mcp==1.11.0. I'll check the behavior and report results soon.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notes] Ah, the change does not contain any field aliases. I believe that modelcontextprotocol/python-sdk#1099 is not required for optuna mcp.

]

[project.scripts]
Expand Down
Loading
Loading