Skip to content

Commit 5221d80

Browse files
authored
feat(Optuna): Allow for parsing of Choice Nodes (#290)
1 parent b680838 commit 5221d80

4 files changed

Lines changed: 367 additions & 14 deletions

File tree

src/amltk/optimization/optimizers/optuna.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def add_to_history(_, report: Trial.Report):
102102
Sorry!
103103
104104
""" # noqa: E501
105+
105106
from __future__ import annotations
106107

107108
from collections.abc import Iterable, Sequence
@@ -291,8 +292,7 @@ def ask(
291292
"""
292293
if n is not None:
293294
return (self.ask(n=None) for _ in range(n))
294-
295-
optuna_trial: optuna.Trial = self.study.ask(self.space)
295+
optuna_trial = self.space.get_trial(self.study)
296296
config = optuna_trial.params
297297
trial_number = optuna_trial.number
298298
unique_name = f"{trial_number=}"

src/amltk/pipeline/parsers/optuna.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@
9191
from __future__ import annotations
9292

9393
from collections.abc import Mapping, Sequence
94-
from typing import TYPE_CHECKING
94+
from dataclasses import dataclass, field
95+
from typing import TYPE_CHECKING, Any
9596

9697
import numpy as np
98+
import optuna
9799
from optuna.distributions import (
98100
BaseDistribution,
99101
CategoricalChoiceType,
@@ -103,17 +105,95 @@
103105
)
104106

105107
from amltk._functional import prefix_keys
108+
from amltk.pipeline.components import Choice
106109

107110
if TYPE_CHECKING:
108-
from typing import TypeAlias
109-
110111
from amltk.pipeline import Node
111112

112-
OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution]
113-
114113
PAIR = 2
115114

116115

116+
@dataclass
117+
class OptunaSearchSpace:
118+
"""A class to represent an Optuna search space.
119+
120+
Wraps a dictionary of hyperparameters and their Optuna distributions.
121+
"""
122+
123+
distributions: dict[str, BaseDistribution] = field(default_factory=dict)
124+
125+
def __repr__(self) -> str:
126+
return f"OptunaSearchSpace({self.distributions})"
127+
128+
def __str__(self) -> str:
129+
return str(self.distributions)
130+
131+
@classmethod
132+
def parse(cls, *args: Any, **kwargs: Any) -> OptunaSearchSpace:
133+
"""Parse a Node into an Optuna search space."""
134+
return parser(*args, **kwargs)
135+
136+
def sample_configuration(self) -> dict[str, Any]:
137+
"""Sample a configuration from the search space using a default Optuna Study."""
138+
study = optuna.create_study()
139+
trial = self.get_trial(study)
140+
return trial.params
141+
142+
def get_trial(self, study: optuna.Study) -> optuna.Trial:
143+
"""Get a trial from a given Optuna Study using this search space."""
144+
optuna_trial: optuna.Trial
145+
if any("__choice__" in k for k in self.distributions):
146+
optuna_trial = study.ask()
147+
# do all __choice__ suggestions with suggest_categorical
148+
workspace = self.distributions.copy()
149+
filter_patterns = []
150+
for name, distribution in workspace.items():
151+
if "__choice__" in name and isinstance(
152+
distribution,
153+
CategoricalDistribution,
154+
):
155+
possible_choices = distribution.choices
156+
choice_made = optuna_trial.suggest_categorical(
157+
name,
158+
choices=possible_choices,
159+
)
160+
for c in possible_choices:
161+
if c != choice_made:
162+
# deletable options have the name of the unwanted choices
163+
filter_patterns.append(f":{c}:")
164+
# filter all parameters for the unwanted choices
165+
filtered_workspace = {
166+
k: v
167+
for k, v in workspace.items()
168+
if (
169+
("__choice__" not in k)
170+
and (
171+
not any(
172+
filter_pattern in k for filter_pattern in filter_patterns
173+
)
174+
)
175+
)
176+
}
177+
# do all remaining suggestions with the correct suggest function
178+
for name, distribution in filtered_workspace.items():
179+
match distribution:
180+
case CategoricalDistribution(choices=choices):
181+
optuna_trial.suggest_categorical(name, choices=choices)
182+
case IntDistribution(
183+
low=low,
184+
high=high,
185+
log=log,
186+
):
187+
optuna_trial.suggest_int(name, low=low, high=high, log=log)
188+
case FloatDistribution(low=low, high=high):
189+
optuna_trial.suggest_float(name, low=low, high=high)
190+
case _:
191+
raise ValueError(f"Unknown distribution: {distribution}")
192+
else:
193+
optuna_trial = study.ask(self.distributions)
194+
return optuna_trial
195+
196+
117197
def _convert_hp_to_optuna_distribution(
118198
name: str,
119199
hp: tuple | Sequence | CategoricalChoiceType | BaseDistribution,
@@ -149,7 +229,7 @@ def _convert_hp_to_optuna_distribution(
149229
raise ValueError(f"Could not parse {name} as a valid Optuna distribution.\n{hp=}")
150230

151231

152-
def _parse_space(node: Node) -> OptunaSearchSpace:
232+
def _parse_space(node: Node) -> dict[str, BaseDistribution]:
153233
match node.space:
154234
case None:
155235
space = {}
@@ -196,13 +276,21 @@ def parser(
196276
197277
delim: The delimiter to use for the names of the hyperparameters.
198278
"""
199-
if conditionals:
200-
raise NotImplementedError("Conditionals are not yet supported with Optuna.")
201-
202279
space = prefix_keys(_parse_space(node), prefix=f"{node.name}{delim}")
203280

204-
for child in node.nodes:
205-
subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim)
281+
children = node.nodes
282+
283+
if isinstance(node, Choice) and any(children):
284+
name = f"{node.name}{delim}__choice__"
285+
space[name] = CategoricalDistribution([child.name for child in children])
286+
287+
for child in children:
288+
subspace = parser(
289+
child,
290+
flat=flat,
291+
conditionals=conditionals,
292+
delim=delim,
293+
).distributions
206294
if not flat:
207295
subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}")
208296

@@ -214,4 +302,4 @@ def parser(
214302
)
215303
space[name] = hp
216304

217-
return space
305+
return OptunaSearchSpace(distributions=space)

tests/optimizers/test_optimizers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from amltk.optimization import Metric, Optimizer, Trial
1212
from amltk.pipeline import Component
13+
from amltk.pipeline.components import Choice
1314
from amltk.profiling import Timer
1415

1516
if TYPE_CHECKING:
@@ -24,6 +25,10 @@ class _A:
2425
pass
2526

2627

28+
class _B:
    # Empty placeholder type; it carries no behaviour of its own.
    ...
30+
31+
2732
metrics = [
2833
Metric("score_bounded", minimize=False, bounds=(0, 1)),
2934
Metric("score_unbounded", minimize=False),
@@ -87,6 +92,25 @@ def opt_optuna(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
8792
)
8893

8994

95+
@case
@parametrize("metric", [*metrics, metrics])  # Single obj and multi
def opt_optuna_choice_hierarchical(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
    """Build an Optuna optimizer over a `Choice` between two components.

    The case is skipped when the Optuna optimizer cannot be imported.
    """
    try:
        from amltk.optimization.optimizers.optuna import OptunaOptimizer
    except ImportError:
        pytest.skip("Optuna is not installed")

    # Two alternatives, each with its own single categorical hyperparameter.
    branches = (
        Component(_A, name="hi1", space={"a": [1, 2, 3]}),
        Component(_B, name="hi2", space={"b": [4, 5, 6]}),
    )
    return OptunaOptimizer.create(
        space=Choice(*branches, name="hi"),
        metrics=metric,
        seed=42,
        bucket=tmp_path,
    )
112+
113+
90114
@case
91115
@parametrize("metric", [*metrics]) # Single obj
92116
def opt_neps(metric: Metric, tmp_path: Path) -> NEPSOptimizer:
@@ -142,3 +166,26 @@ def test_batched_ask_generates_unique_configs(optimizer: Optimizer):
142166
batch = list(optimizer.ask(10))
143167
assert len(batch) == 10
144168
assert all_unique(batch)
169+
170+
171+
@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_output(optimizer: Optimizer):
    """A trial asked from a choice space must report which option was chosen."""
    config = optimizer.ask().config
    assert any("__choice__" in key for key in config), config
176+
177+
178+
@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_no_params_left(optimizer: Optimizer):
    """Only hyperparameters of the selected option may remain in the config."""
    config = optimizer.ask().config
    plain_keys = [key for key in config if "__choice__" not in key]

    for choice_key, selected in config.items():
        if "__choice__" not in choice_key:
            continue
        # Everything sharing the selector's prefix belongs to that choice.
        prefix = choice_key.removesuffix("__choice__")
        params_for_choice = [key for key in plain_keys if key.startswith(prefix)]
        # Check that only params for the chosen choice are left
        assert all(selected in key for key in params_for_choice), params_for_choice

0 commit comments

Comments
 (0)