Commit 43aa30e

fix(ray_tune): do not modify search algorithm parameters in-place (#656)
Signed-off-by: Alessandro Pomponio <alessandro.pomponio1@ibm.com>
Parent: 0518e89
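
Why this matters: the previous implementation resolved sampler names inside `pydantic.model_validator(mode="after")` hooks that wrote the resolved objects back into `self.params`, so merely validating the model silently rewrote the user's configuration. A minimal sketch of that pitfall, using a hypothetical toy model rather than the plugin's real `OrchSearchAlgorithm`:

```python
# Toy sketch (hypothetical model, not the plugin's code): an after-validator
# that mutates self.params in place, as the removed validators did.
import pydantic


class SearchAlg(pydantic.BaseModel):
    name: str
    params: dict = {}

    @pydantic.model_validator(mode="after")
    def resolve_sampler(self) -> "SearchAlg":
        if sampler := self.params.get("sampler"):
            # Side effect: the configured string is overwritten with an object.
            self.params["sampler"] = f"<instance of {sampler}>"
        return self


alg = SearchAlg(name="optuna", params={"sampler": "TPESampler"})
print(alg.params)  # {'sampler': '<instance of TPESampler>'} -- the config was rewritten
```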

File tree

  • plugins/operators/ray_tune/ado_ray_tune

1 file changed: +50 additions, -38 deletions

plugins/operators/ray_tune/ado_ray_tune/config.py

Lines changed: 50 additions & 38 deletions
```diff
@@ -119,54 +119,61 @@ class OrchSearchAlgorithm(pydantic.BaseModel):
             ),
         ]
 
-    @pydantic.model_validator(mode="after")
-    def map_optuna_sampler_name_to_instance(self) -> "OrchSearchAlgorithm":
+    def parameters_for_ray_tune(self) -> dict:
+        match self.name.lower():
+            case "optuna":
+                return self._optuna_parameters_to_ray_tune()
+            case "nevergrad":
+                return self._nevergrad_parameters_to_ray_tune()
+            case _:
+                return self.params.copy()
+
+    def _optuna_parameters_to_ray_tune(self) -> dict:
+
+        ray_tune_parameters = self.params.copy()
+        sampler_parameters = ray_tune_parameters.get("sampler_parameters")
+        optuna_sampler = ray_tune_parameters.get("sampler")
+        if not optuna_sampler and sampler_parameters:
+            raise ValueError(
+                "Optuna sampler parameters specified but no sampler specified"
+            )
 
-        if self.name.lower() != "optuna":
-            return self
+        if optuna_sampler and sampler_parameters:
+            try:
+                import optuna.samplers
+
+                sampler_cls = getattr(optuna.samplers, optuna_sampler)
+            except (ImportError, AttributeError) as ex:
+                raise ImportError(
+                    f"Optuna sampler '{optuna_sampler}' not found in optuna.samplers. Original error: {ex}"
+                ) from ex
+
+            # instantiate the sampler with any provided parameters
+            sampler_instance = (
+                sampler_cls(**sampler_parameters)
+                if sampler_parameters
+                else sampler_cls()
+            )
 
-        sampler_parameters = self.params.get("sampler_parameters")
-        if not (optuna_sampler := self.params.get("sampler")):
-            if sampler_parameters:
-                raise ValueError(
-                    "Optuna sampler parameters specified but no sampler specified"
-                )
-            return self
-
-        try:
-            import optuna.samplers
-
-            sampler_cls = getattr(optuna.samplers, optuna_sampler)
-        except (ImportError, AttributeError) as ex:
-            raise ImportError(
-                f"Optuna sampler '{optuna_sampler}' not found in optuna.samplers. Original error: {ex}"
-            ) from ex
-        # instantiate the sampler with any provided parameters
-        sampler_instance = (
-            sampler_cls(**sampler_parameters) if sampler_parameters else sampler_cls()
-        )
-        self.params["sampler"] = sampler_instance
-        # delete sampler_parameters
-        self.params.pop("sampler_parameters", None)
+            ray_tune_parameters["sampler"] = sampler_instance
+            ray_tune_parameters.pop("sampler_parameters", None)
 
-        return self
+        return ray_tune_parameters
 
-    @pydantic.model_validator(mode="after")
-    def map_nevergrad_optimizer_name_to_type(self) -> "OrchSearchAlgorithm":
+    def _nevergrad_parameters_to_ray_tune(self) -> dict:
 
-        if self.name != "nevergrad":
-            return self
+        ray_tune_parameters = self.params.copy()
 
         # nevergrad wrapper requires passing the class of the optimizer in the "optimizer" param
         # here we have to switch from string to class
         # Note: The NevergradSearch interface types optimizer as optional, but it's not
         # We let Nevergrad handle this
-        if optimizer := self.params.get("optimizer"):
+        if optimizer := ray_tune_parameters.get("optimizer"):
             import nevergrad
 
-            self.params["optimizer"] = nevergrad.optimizers.registry[optimizer]
+            ray_tune_parameters["optimizer"] = nevergrad.optimizers.registry[optimizer]
 
-        return self
+        return ray_tune_parameters
 
 
 class OrchStopperAlgorithm(pydantic.BaseModel):
```
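
The replacement pattern above resolves names on demand and only ever touches a copy. A minimal sketch of the behavior, using a toy stand-in for `OrchSearchAlgorithm` that assumes only what the diff shows:

```python
# Toy sketch: resolution happens in a copy returned to the caller;
# the pydantic model's params stay exactly as the user wrote them.
import pydantic


class SearchAlg(pydantic.BaseModel):
    name: str
    params: dict = {}

    def parameters_for_ray_tune(self) -> dict:
        resolved = self.params.copy()
        if sampler := resolved.get("sampler"):
            # The real helper does getattr(optuna.samplers, sampler)(**kwargs).
            resolved["sampler"] = f"<instance of {sampler}>"
        return resolved


alg = SearchAlg(name="optuna", params={"sampler": "TPESampler"})
print(alg.parameters_for_ray_tune())  # {'sampler': '<instance of TPESampler>'}
print(alg.params)                     # {'sampler': 'TPESampler'} -- unchanged
```

Note that `self.params.copy()` is a shallow copy: the helpers only replace or pop top-level keys such as `sampler` and `sampler_parameters`, never mutate nested values in place, so a shallow copy is sufficient here.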
```diff
@@ -220,32 +227,37 @@ class OrchTuneConfig(pydantic.BaseModel):
     model_config = ConfigDict(extra="allow")
 
     def rayTuneConfig(self) -> ray.tune.TuneConfig:
+
         tune_options = self.model_dump()
+        ray_tune_parameters = self.search_alg.parameters_for_ray_tune()
+
         if self.search_alg.name.lower() == "optuna":
             return create_optuna_ray_tune_config(
                 metric=self.metric,
                 mode=self.mode,
-                parameters=self.search_alg.params,
+                parameters=ray_tune_parameters,
                 tune_options=tune_options,
             )
 
+        # 2026/03/04: at the moment only optuna supports multi-objective optimization
         if isinstance(self.metric, list) or isinstance(self.mode, list):
             raise Exception(
                 f"Multi-objective optimization with {self.search_alg.name} is not supported in ado_ray_tune."
             )
+
         if self.search_alg.name == "lhu_sampler":
             return create_lhu_ray_tune_config(
                 mode=self.mode,
                 metric=self.metric,
                 tune_options=tune_options,
-                parameters=self.search_alg.params,
+                parameters=ray_tune_parameters,
             )
         return create_general_ray_tune_config(
             self.search_alg.name,
             mode=self.mode,
             metric=self.metric,
             tune_options=tune_options,
-            parameters=self.search_alg.params,
+            parameters=ray_tune_parameters,
         )
 
 
```
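For context, the two name-to-object lookups performed by the new helpers map onto library calls along these lines (requires `optuna` and `nevergrad` installed; `TPESampler` and `OnePlusOne` are standard names chosen here purely for illustration):

```python
# Illustration of the resolution steps behind parameters_for_ray_tune().
import optuna.samplers
import nevergrad

# optuna branch: attribute lookup on optuna.samplers, then instantiation
sampler = getattr(optuna.samplers, "TPESampler")(seed=42)

# nevergrad branch: name-to-class lookup in the optimizer registry
optimizer = nevergrad.optimizers.registry["OnePlusOne"]

print(type(sampler).__name__, optimizer)
```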