-
Notifications
You must be signed in to change notification settings - Fork 21
Add Structured Output Support #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
42c2a25
52eebbb
53495be
923d837
f683534
89662bb
7b4e89d
0713969
83f9527
5434429
e001fd5
94eb06b
95e79b1
50d15a9
ac6b526
6518c7b
e6fe94a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,8 @@ | |
| import optuna | ||
| import optuna_dashboard | ||
| import plotly | ||
| from pydantic import BaseModel | ||
| from pydantic import Field | ||
|
|
||
|
|
||
| class OptunaMCP(FastMCP): | ||
|
|
@@ -59,12 +61,38 @@ class TrialToAdd: | |
| system_attrs: dict[str, typing.Any] | None | ||
|
|
||
|
|
||
| class StudyInfo(BaseModel): | ||
| study_name: str | ||
| sampler_name: ( | ||
| typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"] | None | ||
|
||
| ) = Field(default=None, description="The name of the sampler used in the study, if available.") | ||
| directions: list[typing.Literal["minimize", "maximize"]] | None = Field( | ||
| default=None, description="The optimization directions for each objective, if available." | ||
| ) | ||
|
|
||
|
|
||
| class TrialInfo(BaseModel): | ||
| trial_number: int | ||
| params: dict[str, typing.Any] | None = Field( | ||
| default=None, description="The parameter values suggested by the trial." | ||
| ) | ||
| values: list[float] | None = Field( | ||
| default=None, description="The objective values of the trial, if available." | ||
| ) | ||
| user_attrs: dict[str, typing.Any] | None = Field( | ||
| default=None, description="User-defined attributes for the trial, if any." | ||
| ) | ||
| system_attrs: dict[str, typing.Any] | None = Field( | ||
| default=None, description="System-defined attributes for the trial, if any." | ||
| ) | ||
|
|
||
|
|
||
| def register_tools(mcp: OptunaMCP) -> OptunaMCP: | ||
| @mcp.tool() | ||
| def create_study( | ||
| study_name: str, | ||
| directions: list[typing.Literal["minimize", "maximize"]] | None = None, | ||
| ) -> str: | ||
| ) -> StudyInfo: | ||
| """Create a new Optuna study with the given study_name and directions. | ||
|
|
||
| If the study already exists, it will be simply loaded. | ||
|
|
@@ -79,10 +107,10 @@ def create_study( | |
| if mcp.storage is None: | ||
| mcp.storage = mcp.study._storage | ||
|
|
||
| return f"Optuna study {study_name} has been prepared" | ||
| return StudyInfo(study_name=study_name) | ||
|
|
||
| @mcp.tool() | ||
| def get_all_study_names() -> str: | ||
| def get_all_study_names() -> list[StudyInfo] | str: | ||
| """Get all study names from the storage.""" | ||
| storage: str | optuna.storages.BaseStorage | None = None | ||
| if mcp.study is not None: | ||
|
|
@@ -93,10 +121,10 @@ def get_all_study_names() -> str: | |
| return "No storage specified." | ||
|
|
||
| study_names = optuna.get_all_study_names(storage) | ||
| return f"All study names: {study_names}" | ||
| return [StudyInfo(study_name=name) for name in study_names] | ||
|
|
||
| @mcp.tool() | ||
| def ask(search_space: dict) -> str: | ||
| def ask(search_space: dict) -> TrialInfo | str: | ||
| """Suggest new parameters using Optuna | ||
|
|
||
| search_space must be a string that can be evaluated to a dictionary to specify Optuna's distributions. | ||
|
|
@@ -117,10 +145,13 @@ def ask(search_space: dict) -> str: | |
|
|
||
| trial = mcp.study.ask(fixed_distributions=distributions) | ||
|
|
||
| return f"Trial {trial.number} suggested: {json.dumps(trial.params)}" | ||
| return TrialInfo( | ||
| trial_number=trial.number, | ||
| params=trial.params, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def tell(trial_number: int, values: float | list[float]) -> str: | ||
| def tell(trial_number: int, values: float | list[float]) -> TrialInfo: | ||
| """Report the result of a trial""" | ||
| if mcp.study is None: | ||
| raise ValueError("No study has been created. Please create a study first.") | ||
|
|
@@ -131,12 +162,15 @@ def tell(trial_number: int, values: float | list[float]) -> str: | |
| state=optuna.trial.TrialState.COMPLETE, | ||
| skip_if_finished=True, | ||
| ) | ||
| return f"Trial {trial_number} reported with values {json.dumps(values)}" | ||
| return TrialInfo( | ||
| trial_number=trial_number, | ||
| values=[values] if isinstance(values, float) else values, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def set_sampler( | ||
| name: typing.Literal["TPESampler", "NSGAIISampler", "RandomSampler", "GPSampler"], | ||
| ) -> str: | ||
| ) -> StudyInfo: | ||
| """Set the sampler for the study. | ||
| The sampler must be one of the following: | ||
| - TPESampler | ||
|
|
@@ -152,7 +186,10 @@ def set_sampler( | |
| if mcp.study is None: | ||
| raise ValueError("No study has been created. Please create a study first.") | ||
| mcp.study.sampler = sampler | ||
| return f"Sampler set to {name}" | ||
| return StudyInfo( | ||
| study_name=mcp.study.study_name, | ||
| sampler_name=name, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str: | ||
|
|
@@ -168,7 +205,7 @@ def set_trial_user_attr(trial_number: int, key: str, value: typing.Any) -> str: | |
| return f"User attribute {key} set to {json.dumps(value)} for trial {trial_number}" | ||
|
|
||
| @mcp.tool() | ||
| def get_trial_user_attrs(trial_number: int) -> str: | ||
| def get_trial_user_attrs(trial_number: int) -> TrialInfo: | ||
| """Get user attributes in a trial""" | ||
| if mcp.study is None: | ||
| raise ValueError("No study has been created. Please create a study first.") | ||
|
|
@@ -177,7 +214,11 @@ def get_trial_user_attrs(trial_number: int) -> str: | |
| mcp.study._study_id, trial_number | ||
| ) | ||
| trial = storage.get_trial(trial_id) | ||
| return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}" | ||
| # return f"User attributes in trial {trial_number}: {json.dumps(trial.user_attrs)}" | ||
| return TrialInfo( | ||
| trial_number=trial_number, | ||
| user_attrs=trial.user_attrs, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def set_metric_names(metric_names: list[str]) -> str: | ||
|
|
@@ -201,12 +242,15 @@ def get_metric_names() -> str: | |
| return f"Metric names: {json.dumps(mcp.study.metric_names)}" | ||
|
|
||
| @mcp.tool() | ||
| def get_directions() -> str: | ||
| def get_directions() -> StudyInfo: | ||
| """Get the directions of the study.""" | ||
| if mcp.study is None: | ||
| raise ValueError("No study has been created. Please create a study first.") | ||
| directions = [d.name.lower() for d in mcp.study.directions] | ||
| return f"Directions: {json.dumps(directions)}" | ||
| return StudyInfo( | ||
| study_name=mcp.study.study_name, | ||
| directions=directions, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def get_trials() -> str: | ||
|
|
@@ -217,7 +261,7 @@ def get_trials() -> str: | |
| return f"Trials: \n{csv_string}" | ||
|
|
||
| @mcp.tool() | ||
| def best_trial() -> str: | ||
| def best_trial() -> TrialInfo: | ||
| """Get the best trial | ||
|
|
||
| This feature can only be used for single-objective optimization. If your study is multi-objective, use best_trials instead. | ||
|
|
@@ -226,14 +270,29 @@ def best_trial() -> str: | |
| raise ValueError("No study has been created. Please create a study first.") | ||
|
|
||
| trial = mcp.study.best_trial | ||
| return f"Best trial: {trial.number} with params {json.dumps(trial.params)} and value {json.dumps(trial.value)}" | ||
| return TrialInfo( | ||
| trial_number=trial.number, | ||
| params=trial.params, | ||
| values=trial.values, | ||
| user_attrs=trial.user_attrs, | ||
| system_attrs=trial.system_attrs, | ||
| ) | ||
|
|
||
| @mcp.tool() | ||
| def best_trials() -> str: | ||
| def best_trials() -> list[TrialInfo]: | ||
| """Return trials located at the Pareto front in the study.""" | ||
| if mcp.study is None: | ||
| raise ValueError("No study has been created. Please create a study first.") | ||
| return f"Best trials: {mcp.study.best_trials}" | ||
| return [ | ||
| TrialInfo( | ||
| trial_number=trial.number, | ||
| params=trial.params, | ||
| values=trial.values, | ||
| user_attrs=trial.user_attrs, | ||
| system_attrs=trial.system_attrs, | ||
| ) | ||
| for trial in mcp.study.best_trials | ||
| ] | ||
|
|
||
| def _create_trial(trial: TrialToAdd) -> optuna.trial.FrozenTrial: | ||
| """Create a trial from the given parameters.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ dependencies = [ | |
| "plotly>=6.0.1", | ||
| "torch>=2.7.0", | ||
| "bottle>=0.13.4", | ||
| "fastmcp>=2.12.2", | ||
|
||
| ] | ||
|
|
||
| [project.scripts] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.