9191from __future__ import annotations
9292
9393from collections .abc import Mapping , Sequence
94- from typing import TYPE_CHECKING
94+ from dataclasses import dataclass , field
95+ from typing import TYPE_CHECKING , Any
9596
9697import numpy as np
98+ import optuna
9799from optuna .distributions import (
98100 BaseDistribution ,
99101 CategoricalChoiceType ,
103105)
104106
105107from amltk ._functional import prefix_keys
108+ from amltk .pipeline .components import Choice
106109
107110if TYPE_CHECKING :
108- from typing import TypeAlias
109-
110111 from amltk .pipeline import Node
111112
112- OptunaSearchSpace : TypeAlias = dict [str , BaseDistribution ]
113-
114113PAIR = 2
115114
116115
116+ @dataclass
117+ class OptunaSearchSpace :
118+ """A class to represent an Optuna search space.
119+
120+ Wraps a dictionary of hyperparameters and their Optuna distributions.
121+ """
122+
123+ distributions : dict [str , BaseDistribution ] = field (default_factory = dict )
124+
125+ def __repr__ (self ) -> str :
126+ return f"OptunaSearchSpace({ self .distributions } )"
127+
128+ def __str__ (self ) -> str :
129+ return str (self .distributions )
130+
131+ @classmethod
132+ def parse (cls , * args : Any , ** kwargs : Any ) -> OptunaSearchSpace :
133+ """Parse a Node into an Optuna search space."""
134+ return parser (* args , ** kwargs )
135+
136+ def sample_configuration (self ) -> dict [str , Any ]:
137+ """Sample a configuration from the search space using a default Optuna Study."""
138+ study = optuna .create_study ()
139+ trial = self .get_trial (study )
140+ return trial .params
141+
142+ def get_trial (self , study : optuna .Study ) -> optuna .Trial :
143+ """Get a trial from a given Optuna Study using this search space."""
144+ optuna_trial : optuna .Trial
145+ if any ("__choice__" in k for k in self .distributions ):
146+ optuna_trial = study .ask ()
147+ # do all __choice__ suggestions with suggest_categorical
148+ workspace = self .distributions .copy ()
149+ filter_patterns = []
150+ for name , distribution in workspace .items ():
151+ if "__choice__" in name and isinstance (
152+ distribution ,
153+ CategoricalDistribution ,
154+ ):
155+ possible_choices = distribution .choices
156+ choice_made = optuna_trial .suggest_categorical (
157+ name ,
158+ choices = possible_choices ,
159+ )
160+ for c in possible_choices :
161+ if c != choice_made :
162+ # deletable options have the name of the unwanted choices
163+ filter_patterns .append (f":{ c } :" )
164+ # filter all parameters for the unwanted choices
165+ filtered_workspace = {
166+ k : v
167+ for k , v in workspace .items ()
168+ if (
169+ ("__choice__" not in k )
170+ and (
171+ not any (
172+ filter_pattern in k for filter_pattern in filter_patterns
173+ )
174+ )
175+ )
176+ }
177+ # do all remaining suggestions with the correct suggest function
178+ for name , distribution in filtered_workspace .items ():
179+ match distribution :
180+ case CategoricalDistribution (choices = choices ):
181+ optuna_trial .suggest_categorical (name , choices = choices )
182+ case IntDistribution (
183+ low = low ,
184+ high = high ,
185+ log = log ,
186+ ):
187+ optuna_trial .suggest_int (name , low = low , high = high , log = log )
188+ case FloatDistribution (low = low , high = high ):
189+ optuna_trial .suggest_float (name , low = low , high = high )
190+ case _:
191+ raise ValueError (f"Unknown distribution: { distribution } " )
192+ else :
193+ optuna_trial = study .ask (self .distributions )
194+ return optuna_trial
195+
196+
117197def _convert_hp_to_optuna_distribution (
118198 name : str ,
119199 hp : tuple | Sequence | CategoricalChoiceType | BaseDistribution ,
@@ -149,7 +229,7 @@ def _convert_hp_to_optuna_distribution(
149229 raise ValueError (f"Could not parse { name } as a valid Optuna distribution.\n { hp = } " )
150230
151231
152- def _parse_space (node : Node ) -> OptunaSearchSpace :
232+ def _parse_space (node : Node ) -> dict [ str , BaseDistribution ] :
153233 match node .space :
154234 case None :
155235 space = {}
@@ -196,13 +276,21 @@ def parser(
196276
197277 delim: The delimiter to use for the names of the hyperparameters.
198278 """
199- if conditionals :
200- raise NotImplementedError ("Conditionals are not yet supported with Optuna." )
201-
202279 space = prefix_keys (_parse_space (node ), prefix = f"{ node .name } { delim } " )
203280
204- for child in node .nodes :
205- subspace = parser (child , flat = flat , conditionals = conditionals , delim = delim )
281+ children = node .nodes
282+
283+ if isinstance (node , Choice ) and any (children ):
284+ name = f"{ node .name } { delim } __choice__"
285+ space [name ] = CategoricalDistribution ([child .name for child in children ])
286+
287+ for child in children :
288+ subspace = parser (
289+ child ,
290+ flat = flat ,
291+ conditionals = conditionals ,
292+ delim = delim ,
293+ ).distributions
206294 if not flat :
207295 subspace = prefix_keys (subspace , prefix = f"{ node .name } { delim } " )
208296
@@ -214,4 +302,4 @@ def parser(
214302 )
215303 space [name ] = hp
216304
217- return space
305+ return OptunaSearchSpace ( distributions = space )
0 commit comments