@@ -119,54 +119,61 @@ class OrchSearchAlgorithm(pydantic.BaseModel):
119119 ),
120120 ]
121121
122- @pydantic .model_validator (mode = "after" )
123- def map_optuna_sampler_name_to_instance (self ) -> "OrchSearchAlgorithm" :
122+ def parameters_for_ray_tune (self ) -> dict :
123+ match self .name .lower ():
124+ case "optuna" :
125+ return self ._optuna_parameters_to_ray_tune ()
126+ case "nevergrad" :
127+ return self ._nevergrad_parameters_to_ray_tune ()
128+ case _:
129+ return self .params .copy ()
130+
131+ def _optuna_parameters_to_ray_tune (self ) -> dict :
132+
133+ ray_tune_parameters = self .params .copy ()
134+ sampler_parameters = ray_tune_parameters .get ("sampler_parameters" )
135+ optuna_sampler = ray_tune_parameters .get ("sampler" )
136+ if not optuna_sampler and sampler_parameters :
137+ raise ValueError (
138+ "Optuna sampler parameters specified but no sampler specified"
139+ )
124140
125- if self .name .lower () != "optuna" :
126- return self
141+ if optuna_sampler and sampler_parameters :
142+ try :
143+ import optuna .samplers
144+
145+ sampler_cls = getattr (optuna .samplers , optuna_sampler )
146+ except (ImportError , AttributeError ) as ex :
147+ raise ImportError (
148+ f"Optuna sampler '{ optuna_sampler } ' not found in optuna.samplers. Original error: { ex } "
149+ ) from ex
150+
151+ # instantiate the sampler with any provided parameters
152+ sampler_instance = (
153+ sampler_cls (** sampler_parameters )
154+ if sampler_parameters
155+ else sampler_cls ()
156+ )
127157
128- sampler_parameters = self .params .get ("sampler_parameters" )
129- if not (optuna_sampler := self .params .get ("sampler" )):
130- if sampler_parameters :
131- raise ValueError (
132- "Optuna sampler parameters specified but no sampler specified"
133- )
134- return self
135-
136- try :
137- import optuna .samplers
138-
139- sampler_cls = getattr (optuna .samplers , optuna_sampler )
140- except (ImportError , AttributeError ) as ex :
141- raise ImportError (
142- f"Optuna sampler '{ optuna_sampler } ' not found in optuna.samplers. Original error: { ex } "
143- ) from ex
144- # instantiate the sampler with any provided parameters
145- sampler_instance = (
146- sampler_cls (** sampler_parameters ) if sampler_parameters else sampler_cls ()
147- )
148- self .params ["sampler" ] = sampler_instance
149- # delete sampler_parameters
150- self .params .pop ("sampler_parameters" , None )
158+ ray_tune_parameters ["sampler" ] = sampler_instance
159+ ray_tune_parameters .pop ("sampler_parameters" , None )
151160
152- return self
161+ return ray_tune_parameters
153162
154- @pydantic .model_validator (mode = "after" )
155- def map_nevergrad_optimizer_name_to_type (self ) -> "OrchSearchAlgorithm" :
163+ def _nevergrad_parameters_to_ray_tune (self ) -> dict :
156164
157- if self .name != "nevergrad" :
158- return self
165+ ray_tune_parameters = self .params .copy ()
159166
160167 # nevergrad wrapper requires passing the class of the optimizer in the "optimizer" param
161168 # here we have to switch from string to class
162169 # Note: The NevergradSearch interface types optimizer as optional, but it's not
163170 # We let Nevergrad handle this
164- if optimizer := self . params .get ("optimizer" ):
171+ if optimizer := ray_tune_parameters .get ("optimizer" ):
165172 import nevergrad
166173
167- self . params ["optimizer" ] = nevergrad .optimizers .registry [optimizer ]
174+ ray_tune_parameters ["optimizer" ] = nevergrad .optimizers .registry [optimizer ]
168175
169- return self
176+ return ray_tune_parameters
170177
171178
172179class OrchStopperAlgorithm (pydantic .BaseModel ):
@@ -220,32 +227,37 @@ class OrchTuneConfig(pydantic.BaseModel):
220227 model_config = ConfigDict (extra = "allow" )
221228
222229 def rayTuneConfig (self ) -> ray .tune .TuneConfig :
230+
223231 tune_options = self .model_dump ()
232+ ray_tune_parameters = self .search_alg .parameters_for_ray_tune ()
233+
224234 if self .search_alg .name .lower () == "optuna" :
225235 return create_optuna_ray_tune_config (
226236 metric = self .metric ,
227237 mode = self .mode ,
228- parameters = self . search_alg . params ,
238+ parameters = ray_tune_parameters ,
229239 tune_options = tune_options ,
230240 )
231241
242+ # 2026/03/04: at the moment only optuna supports multi-objective optimization
232243 if isinstance (self .metric , list ) or isinstance (self .mode , list ):
233244 raise Exception (
234245 f"Multi-objective optimization with { self .search_alg .name } is not supported in ado_ray_tune."
235246 )
247+
236248 if self .search_alg .name == "lhu_sampler" :
237249 return create_lhu_ray_tune_config (
238250 mode = self .mode ,
239251 metric = self .metric ,
240252 tune_options = tune_options ,
241- parameters = self . search_alg . params ,
253+ parameters = ray_tune_parameters ,
242254 )
243255 return create_general_ray_tune_config (
244256 self .search_alg .name ,
245257 mode = self .mode ,
246258 metric = self .metric ,
247259 tune_options = tune_options ,
248- parameters = self . search_alg . params ,
260+ parameters = ray_tune_parameters ,
249261 )
250262
251263
0 commit comments