Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 96 additions & 34 deletions optuna_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import optuna
import optuna_dashboard
import plotly
from pydantic import BaseModel
from pydantic import Field


class OptunaMCP(FastMCP):
Expand Down Expand Up @@ -59,12 +61,38 @@ class TrialToAdd:
system_attrs: dict[str, typing.Any] | None


class StudyResponse(BaseModel):
study_name: str
sampler_name: (
typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] | None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notes] A nit, but supported samplers are listed in the definition of set_sampler as well. We may define SamplerName=typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] to prevent the inconsistency.

But we can work on it in a follow-up PR.

) = Field(default=None, description="The name of the sampler used in the study, if available.")
directions: list[str] | None = Field(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of the directions attribute is not consistent with the argument of create_study.

Suggested change
directions: list[str] | None = Field(
directions: list[typing.Literal["minimize", "maximize"]] | None = Field(

default=None, description="The optimization directions for each objective, if available."
)


class TrialResponse(BaseModel):
    """Structured MCP response describing a single Optuna trial.

    Only ``trial_number`` is always present; the remaining fields are
    populated by individual tools depending on what information the
    underlying Optuna trial exposes at that point in its lifecycle.
    """

    # Consistency fix: every other field carries a Field description used in the
    # generated structured-output schema; trial_number previously had none.
    trial_number: int = Field(
        description="The number of the trial within the study."
    )
    params: dict[str, typing.Any] | None = Field(
        default=None, description="The parameter values suggested by the trial."
    )
    values: list[float] | None = Field(
        default=None, description="The objective values of the trial, if available."
    )
    user_attrs: dict[str, typing.Any] | None = Field(
        default=None, description="User-defined attributes for the trial, if any."
    )
    system_attrs: dict[str, typing.Any] | None = Field(
        default=None, description="System-defined attributes for the trial, if any."
    )


def register_tools(mcp: OptunaMCP) -> OptunaMCP:
@mcp.tool()
@mcp.tool(structured_output=True)
def create_study(
study_name: str,
directions: list[typing.Literal["minimize", "maximize"]] | None = None,
) -> str:
) -> StudyResponse:
"""Create a new Optuna study with the given study_name and directions.

If the study already exists, it will be simply loaded.
Expand All @@ -79,10 +107,10 @@ def create_study(
if mcp.storage is None:
mcp.storage = mcp.study._storage

return f"Optuna study {study_name} has been prepared"
return StudyResponse(study_name=study_name)

@mcp.tool()
def get_all_study_names() -> str:
@mcp.tool(structured_output=True)
def get_all_study_names() -> list[StudyResponse] | str:
"""Get all study names from the storage."""
storage: str | optuna.storages.BaseStorage | None = None
if mcp.study is not None:
Expand All @@ -93,10 +121,10 @@ def get_all_study_names() -> str:
return "No storage specified."

study_names = optuna.get_all_study_names(storage)
return f"All study names: {study_names}"
return [StudyResponse(study_name=name) for name in study_names]

@mcp.tool()
def ask(search_space: dict) -> str:
@mcp.tool(structured_output=True)
def ask(search_space: dict) -> TrialResponse | str:
"""Suggest new parameters using Optuna

search_space must be a string that can be evaluated to a dictionary to specify Optuna's distributions.
Expand All @@ -117,10 +145,13 @@ def ask(search_space: dict) -> str:

trial = mcp.study.ask(fixed_distributions=distributions)

return f"Trial {trial.number} suggested: {json.dumps(trial.params)}"
return TrialResponse(
trial_number=trial.number,
params=trial.params,
)

@mcp.tool()
def tell(trial_number: int, values: float | list[float]) -> str:
@mcp.tool(structured_output=True)
def tell(trial_number: int, values: float | list[float]) -> TrialResponse:
"""Report the result of a trial"""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
Expand All @@ -131,12 +162,15 @@ def tell(trial_number: int, values: float | list[float]) -> str:
state=optuna.trial.TrialState.COMPLETE,
skip_if_finished=True,
)
return f"Trial {trial_number} reported with values {json.dumps(values)}"
return TrialResponse(
trial_number=trial_number,
values=[values] if isinstance(values, float) else values,
)

@mcp.tool()
@mcp.tool(structured_output=True)
def set_sampler(
name: typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"],
) -> str:
) -> StudyResponse:
"""Set the sampler for the study.
The sampler must be one of the following:
- TPESampler
Expand All @@ -152,9 +186,12 @@ def set_sampler(
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
mcp.study.sampler = sampler
return f"Sampler set to {name}"
return StudyResponse(
study_name=mcp.study.study_name,
sampler_name=name,
)

@mcp.tool()
@mcp.tool(structured_output=True)
def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str:
"""Set user attributes for a trial"""
if mcp.study is None:
Expand All @@ -167,8 +204,8 @@ def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str:
storage.set_trial_user_attr(trial_id, key, value)
return f"User attribute {key} set to {json.dumps(value)} for trial {trial_number}"

@mcp.tool()
def get_trial_user_attrs(trial_number: int) -> str:
@mcp.tool(structured_output=True)
def get_trial_user_attrs(trial_number: int) -> TrialResponse:
"""Get user attributes in a trial"""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
Expand All @@ -177,9 +214,13 @@ def get_trial_user_attrs(trial_number: int) -> str:
mcp.study._study_id, trial_number
)
trial = storage.get_trial(trial_id)
return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}"
# return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}"
return TrialResponse(
trial_number=trial_number,
user_attrs=trial.user_attrs,
)

@mcp.tool()
@mcp.tool(structured_output=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notes] Since metric names correspond to the study's directions, we might be able to include the response in StudyResponse as follows:

StudyResponse:
    ...
    metric_names: | None = Field(default=None, description="...")

But this new approach may need further discussion, so we can work on it in a follow-up PR.

def set_metric_names(metric_names: list[str]) -> str:
"""Set metric_names. metric_names are labels used to distinguish what each objective value is.

Expand All @@ -193,31 +234,34 @@ def set_metric_names(metric_names: list[str]) -> str:
mcp.study.set_metric_names(metric_names)
return f"metric_names set to {json.dumps(metric_names)}"

@mcp.tool()
@mcp.tool(structured_output=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

def get_metric_names() -> str:
    """Get metric_names"""
    # Guard clause: tools require an active study before touching it.
    study = mcp.study
    if study is None:
        raise ValueError("No study has been created. Please create a study first.")
    encoded = json.dumps(study.metric_names)
    return f"Metric names: {encoded}"

@mcp.tool()
def get_directions() -> str:
@mcp.tool(structured_output=True)
def get_directions() -> StudyResponse:
"""Get the directions of the study."""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
directions = [d.name.lower() for d in mcp.study.directions]
return f"Directions: {json.dumps(directions)}"
return StudyResponse(
study_name=mcp.study.study_name,
directions=directions,
)

@mcp.tool()
@mcp.tool(structured_output=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative approach is to return the list of TrialResponse.
So, how about keeping the structured_output=False for future updates?

Suggested change
@mcp.tool(structured_output=True)
@mcp.tool()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of MCP sets the default value of structured_output to None, which automatically chose the return type (either structured or unstructured). In the case of get_trials, structured output is returned if structured_output is not specified.

To preserve the current behavior of main for future updates, I explicitly set structured_output=False. This also simplifies the test cases, as mentioned in this discussion.

def get_trials() -> str:
    """Get all trials in a CSV format"""
    # Guard clause: tools require an active study before touching it.
    study = mcp.study
    if study is None:
        raise ValueError("No study has been created. Please create a study first.")
    # trials_dataframe() flattens all trials; to_csv() serializes them as text.
    return "Trials: \n" + study.trials_dataframe().to_csv()

@mcp.tool()
def best_trial() -> str:
@mcp.tool(structured_output=True)
def best_trial() -> TrialResponse:
"""Get the best trial

This feature can only be used for single-objective optimization. If your study is multi-objective, use best_trials instead.
Expand All @@ -226,14 +270,29 @@ def best_trial() -> str:
raise ValueError("No study has been created. Please create a study first.")

trial = mcp.study.best_trial
return f"Best trial: {trial.number} with params {json.dumps(trial.params)} and value {json.dumps(trial.value)}"
return TrialResponse(
trial_number=trial.number,
params=trial.params,
values=trial.values,
user_attrs=trial.user_attrs,
system_attrs=trial.system_attrs,
)

@mcp.tool()
def best_trials() -> str:
@mcp.tool(structured_output=True)
def best_trials() -> list[TrialResponse]:
"""Return trials located at the Pareto front in the study."""
if mcp.study is None:
raise ValueError("No study has been created. Please create a study first.")
return f"Best trials: {mcp.study.best_trials}"
return [
TrialResponse(
trial_number=trial.number,
params=trial.params,
values=trial.values,
user_attrs=trial.user_attrs,
system_attrs=trial.system_attrs,
)
Comment on lines +363 to +369
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I found that the proposed best_trial returns more information, such as user_attrs and system_attrs, than the previous implementation, while the proposed best_trials returns less information, such as distributions, datetime_start, datetime_end, and intermediate_values.

This PR did not introduce this gap between best_trial and best_trials, and we can address this issue in a follow-up PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing that out. At the moment, I don’t have a strong preference regarding what should be returned in best_trial and best_trials. This issue will be addressed in a follow-up PR.

for trial in mcp.study.best_trials
]

def _create_trial(trial: TrialToAdd) -> optuna.trial.FrozenTrial:
"""Create a trial from the given parameters."""
Expand All @@ -249,15 +308,15 @@ def _create_trial(trial: TrialToAdd) -> optuna.trial.FrozenTrial:
system_attrs=trial.system_attrs,
)

@mcp.tool()
@mcp.tool(structured_output=True)
def add_trial(trial: TrialToAdd) -> str:
    """Add a trial to the study."""
    # Guard clause: tools require an active study before touching it.
    study = mcp.study
    if study is None:
        raise ValueError("No study has been created. Please create a study first.")
    # Convert the request payload into a FrozenTrial and register it.
    frozen = _create_trial(trial)
    study.add_trial(frozen)
    return "Trial was added."

@mcp.tool()
@mcp.tool(structured_output=True)
def add_trials(trials: list[TrialToAdd]) -> str:
"""Add multiple trials to the study."""
frozen_trials = [_create_trial(trial) for trial in trials]
Expand Down Expand Up @@ -482,7 +541,7 @@ def plot_rank(
fig = optuna.visualization.plot_rank(mcp.study)
return Image(data=plotly.io.to_image(fig), format="png")

@mcp.tool()
@mcp.tool(structured_output=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we can define the response type for the optuna dashboard, such as:

class OptunaDashboardResponse(BaseModel):
    url: str = Field(description="The URL of the Optuna dashboard.")

But this lacks information on whether the server has been newly started or not, so we can discuss this approach in a separate PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining a response type for the Optuna dashboard sounds like a good idea. This will also be addressed in a follow-up PR.

def launch_optuna_dashboard(port: int = 58080) -> str:
"""Launch the Optuna dashboard"""
storage: str | optuna.storages.BaseStorage | None = None
Expand Down Expand Up @@ -540,3 +599,6 @@ def main() -> None:

if __name__ == "__main__":
main()

mcp = OptunaMCP("Optuna")
mcp = register_tools(mcp)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
requires-python = ">=3.12"
dependencies = [
"kaleido==0.2.1",
"mcp[cli]>=1.5.0",
"mcp[cli]>=1.10.0",
"optuna>=4.2.1",
"optuna-dashboard>=0.18.0",
"pandas>=2.2.3",
Expand Down
Loading