Skip to content

Commit 53495be

Browse files
committed
format
1 parent 52eebbb commit 53495be

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

optuna_mcp/server.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import optuna
1111
import optuna_dashboard
1212
import plotly
13-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel
14+
from pydantic import Field
1415

1516

1617
class OptunaMCP(FastMCP):
@@ -59,17 +60,32 @@ class TrialToAdd:
5960
user_attrs: dict[str, typing.Any] | None
6061
system_attrs: dict[str, typing.Any] | None
6162

63+
6264
class StudyInfo(BaseModel):
6365
study_name: str
64-
sampler_name: typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] | None = Field(default=None, description="The name of the sampler used in the study, if available.")
65-
directions: list[typing.Literal["minimize", "maximize"]] | None = Field(default=None, description="The optimization directions for each objective, if available.")
66+
sampler_name: (
67+
typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] | None
68+
) = Field(default=None, description="The name of the sampler used in the study, if available.")
69+
directions: list[typing.Literal["minimize", "maximize"]] | None = Field(
70+
default=None, description="The optimization directions for each objective, if available."
71+
)
72+
6673

6774
class TrialInfo(BaseModel):
6875
trial_number: int
69-
params: dict[str, typing.Any] | None = Field(default = None, description="The parameter values suggested by the trial.")
70-
values: list[float] | None = Field(default = None, description="The objective values of the trial, if available.")
71-
user_attrs: dict[str, typing.Any] | None = Field(default = None, description="User-defined attributes for the trial, if any.")
72-
system_attrs: dict[str, typing.Any] | None = Field(default = None, description="System-defined attributes for the trial, if any.")
76+
params: dict[str, typing.Any] | None = Field(
77+
default=None, description="The parameter values suggested by the trial."
78+
)
79+
values: list[float] | None = Field(
80+
default=None, description="The objective values of the trial, if available."
81+
)
82+
user_attrs: dict[str, typing.Any] | None = Field(
83+
default=None, description="User-defined attributes for the trial, if any."
84+
)
85+
system_attrs: dict[str, typing.Any] | None = Field(
86+
default=None, description="System-defined attributes for the trial, if any."
87+
)
88+
7389

7490
def register_tools(mcp: OptunaMCP) -> OptunaMCP:
7591
@mcp.tool()
@@ -94,7 +110,7 @@ def create_study(
94110
return StudyInfo(study_name=study_name)
95111

96112
@mcp.tool()
97-
def get_all_study_names() -> list[StudyInfo]:
113+
def get_all_study_names() -> list[StudyInfo] | str:
98114
"""Get all study names from the storage."""
99115
storage: str | optuna.storages.BaseStorage | None = None
100116
if mcp.study is not None:
@@ -108,7 +124,7 @@ def get_all_study_names() -> list[StudyInfo]:
108124
return [StudyInfo(study_name=name) for name in study_names]
109125

110126
@mcp.tool()
111-
def ask(search_space: dict) -> TrialInfo:
127+
def ask(search_space: dict) -> TrialInfo | str:
112128
"""Suggest new parameters using Optuna
113129
114130
search_space must be a string that can be evaluated to a dictionary to specify Optuna's distributions.
@@ -149,7 +165,7 @@ def tell(trial_number: int, values: float | list[float]) -> TrialInfo:
149165
return TrialInfo(
150166
trial_number=trial_number,
151167
values=[values] if isinstance(values, float) else values,
152-
)
168+
)
153169

154170
@mcp.tool()
155171
def set_sampler(
@@ -267,13 +283,16 @@ def best_trials() -> list[TrialInfo]:
267283
"""Return trials located at the Pareto front in the study."""
268284
if mcp.study is None:
269285
raise ValueError("No study has been created. Please create a study first.")
270-
return [TrialInfo(
271-
trial_number=trial.number,
272-
params=trial.params,
273-
values=trial.values,
274-
user_attrs=trial.user_attrs,
275-
system_attrs=trial.system_attrs,
276-
) for trial in mcp.study.best_trials]
286+
return [
287+
TrialInfo(
288+
trial_number=trial.number,
289+
params=trial.params,
290+
values=trial.values,
291+
user_attrs=trial.user_attrs,
292+
system_attrs=trial.system_attrs,
293+
)
294+
for trial in mcp.study.best_trials
295+
]
277296

278297
def _create_trial(trial: TrialToAdd) -> optuna.trial.FrozenTrial:
279298
"""Create a trial from the given parameters."""

0 commit comments

Comments
 (0)