Skip to content

Commit 1f75852

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Deprecate Trial.runner (#4460)
Summary: Pull Request resolved: #4460 Upstream functionality from MultiTypeExperiment._runner_by_trial_type Full details in: T222906773 and [doc](https://docs.google.com/document/d/1u0J0VA2VzMkJO4n-J-8ngR0PNdd4CgCmePp8HUFRuTA/edit?tab=t.z51rdls4788k), but a summary is: {F1982739301} Reviewed By: mpolson64 Differential Revision: D83001393 fbshipit-source-id: e06b70c9fb085e4f2b9fd7666dc90f0017cd8c2b
1 parent 0b7c7af commit 1f75852

19 files changed

Lines changed: 123 additions & 127 deletions

ax/core/base_trial.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,6 @@ def __init__(
129129
self._run_metadata: dict[str, Any] = {}
130130
self._stop_metadata: dict[str, Any] = {}
131131

132-
self._runner: Runner | None = None
133-
134132
# Counter to maintain how many arms have been named by this BatchTrial
135133
self._num_arms_created = 0
136134

@@ -246,12 +244,30 @@ def did_not_complete(self) -> bool:
246244
@property
247245
def runner(self) -> Runner | None:
248246
"""The runner object defining how to deploy the trial."""
249-
return self._runner
247+
return self.experiment.runner_for_trial_type(self.trial_type)
250248

251249
@runner.setter
252-
@immutable_once_run
253250
def runner(self, runner: Runner | None) -> None:
254-
self._runner = runner
251+
raise UnsupportedError(
252+
"Setting runner on individual trials is no longer supported. "
253+
"Use experiment-level runners instead."
254+
)
255+
256+
@property
257+
def _runner(self) -> Runner | None:
258+
"""Private runner access is not supported."""
259+
raise UnsupportedError(
260+
"Accessing _runner on individual trials is no longer supported. "
261+
"Use trial.runner instead, which gets the runner from the experiment."
262+
)
263+
264+
@_runner.setter
265+
def _runner(self, runner: Runner | None) -> None:
266+
"""Private runner setting is not supported."""
267+
raise UnsupportedError(
268+
"Setting _runner on individual trials is no longer supported. "
269+
"Use experiment-level runners instead."
270+
)
255271

256272
@property
257273
def deployed_name(self) -> str | None:
@@ -305,13 +321,6 @@ def _add_generator_run(self, generator_run: GeneratorRun) -> None:
305321
# 4. TODO: Capture which generator run the arms we are about to add this
306322
# this trial, came from.
307323

308-
def assign_runner(self) -> BaseTrial:
309-
"""Assigns default experiment runner if trial doesn't already have one."""
310-
runner = self.experiment.runner_for_trial(self)
311-
if runner is not None:
312-
self._runner = runner.clone()
313-
return self
314-
315324
def update_run_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
316325
"""Updates the run metadata dict stored on this trial and returns the
317326
updated dict."""
@@ -338,15 +347,12 @@ def run(self) -> BaseTrial:
338347
if self.status != TrialStatus.CANDIDATE:
339348
raise ValueError("Can only run a candidate trial.")
340349

341-
# Default to experiment runner if trial doesn't have one
342-
self.assign_runner()
343-
344-
if self._runner is None:
345-
raise ValueError("No runner set on trial or experiment.")
350+
if self.runner is None:
351+
raise ValueError("No runner set on experiment.")
346352

347-
self.update_run_metadata(none_throws(self._runner).run(self))
353+
self.update_run_metadata(none_throws(self.runner).run(self))
348354

349-
if none_throws(self._runner).staging_required:
355+
if none_throws(self.runner).staging_required:
350356
self.mark_staged()
351357
else:
352358
self.mark_running()
@@ -379,11 +385,9 @@ def stop(self, new_status: TrialStatus, reason: str | None = None) -> BaseTrial:
379385
"COMPLETED, ABANDONED or EARLY_STOPPED."
380386
)
381387

382-
# Default to experiment runner if trial doesn't have one
383-
self.assign_runner()
384-
if self._runner is None:
385-
raise ValueError("No runner set on trial or experiment.")
386-
runner = none_throws(self._runner)
388+
if self.runner is None:
389+
raise ValueError("No runner set on experiment.")
390+
runner = none_throws(self.runner)
387391

388392
self._stop_metadata = runner.stop(self, reason=reason)
389393
self.mark_as(new_status)
@@ -592,7 +596,7 @@ def mark_running(
592596

593597
prev_step = (
594598
TrialStatus.STAGED
595-
if self._runner is not None and self._runner.staging_required
599+
if self.runner is not None and self.runner.staging_required
596600
else TrialStatus.CANDIDATE
597601
)
598602
prev_step_str = "staged" if prev_step == TrialStatus.STAGED else "candidate"
@@ -897,7 +901,6 @@ def _update_trial_attrs_on_clone(
897901
new_trial._run_metadata = deepcopy(self._run_metadata)
898902
new_trial._stop_metadata = deepcopy(self._stop_metadata)
899903
new_trial._num_arms_created = self._num_arms_created
900-
new_trial.runner = self._runner.clone() if self._runner else None
901904

902905
# Set status and reason accordingly.
903906
if self.status == TrialStatus.CANDIDATE:

ax/core/experiment.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
default_data_type: DataType | None = None,
101101
auxiliary_experiments_by_purpose: None
102102
| (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None,
103+
default_trial_type: str | None = None,
103104
) -> None:
104105
"""Inits Experiment.
105106
@@ -130,7 +131,7 @@ def __init__(
130131

131132
self._name = name
132133
self.description = description
133-
self.runner = runner
134+
self._runner = runner
134135
self.is_test: bool = is_test
135136

136137
self._data_by_trial: dict[int, OrderedDict[int, Data]] = {}
@@ -141,6 +142,12 @@ def __init__(
141142
self._trials: dict[int, BaseTrial] = {}
142143
self._properties: dict[str, Any] = properties or {}
143144
self._default_data_type: DataType = default_data_type or DataType.DATA
145+
146+
# Initialize trial type to runner mapping
147+
self._default_trial_type = default_trial_type
148+
self._trial_type_to_runner: dict[str | None, Runner | None] = {
149+
default_trial_type: runner
150+
}
144151
# Used to keep track of whether any trials on the experiment
145152
# specify a TTL. Since trials need to be checked for their TTL's
146153
# expiration often, having this attribute helps avoid unnecessary
@@ -391,6 +398,26 @@ def status_quo(self, status_quo: Arm | None) -> None:
391398

392399
self._status_quo = status_quo
393400

401+
@property
402+
def runner(self) -> Runner | None:
403+
"""Default runner used for trials on this experiment."""
404+
return self._runner
405+
406+
@runner.setter
407+
def runner(self, runner: Runner | None) -> None:
408+
"""Set the default runner and update trial type mapping."""
409+
self._runner = runner
410+
if runner is not None:
411+
self._trial_type_to_runner[self._default_trial_type] = runner
412+
else:
413+
self._trial_type_to_runner = {None: None}
414+
415+
@runner.deleter
416+
def runner(self) -> None:
417+
"""Delete the runner."""
418+
self._runner = None
419+
self._trial_type_to_runner = {None: None}
420+
394421
@property
395422
def parameters(self) -> dict[str, Parameter]:
396423
"""The parameters in the experiment's search space."""
@@ -1327,7 +1354,7 @@ def stop_trial_runs(
13271354
reasons = [None] * len(trials)
13281355

13291356
for trial, reason in zip(trials, reasons):
1330-
runner = self.runner_for_trial(trial=trial)
1357+
runner = self.runner_for_trial_type(trial_type=trial.trial_type)
13311358
if runner is None:
13321359
raise RunnerNotFoundError(
13331360
"Unable to stop trial runs: Runner not configured "
@@ -1336,17 +1363,6 @@ def stop_trial_runs(
13361363
runner.stop(trial=trial, reason=reason)
13371364
trial.mark_early_stopped()
13381365

1339-
def reset_runners(self, runner: Runner) -> None:
1340-
"""Replace all candidate trials runners.
1341-
1342-
Args:
1343-
runner: New runner to replace with.
1344-
"""
1345-
for trial in self._trials.values():
1346-
if trial.status == TrialStatus.CANDIDATE:
1347-
trial.runner = runner
1348-
self.runner = runner
1349-
13501366
def _attach_trial(self, trial: BaseTrial, index: int | None = None) -> int:
13511367
"""Attach a trial to this experiment.
13521368
@@ -1648,15 +1664,18 @@ def default_trial_type(self) -> str | None:
16481664
In the base experiment class this is always None. For experiments
16491665
with multiple trial types, use the MultiTypeExperiment class.
16501666
"""
1651-
return None
1667+
return self._default_trial_type
16521668

1653-
def runner_for_trial(self, trial: BaseTrial) -> Runner | None:
1654-
"""The default runner to use for a given trial.
1669+
def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
1670+
"""The default runner to use for a given trial type.
16551671
1656-
In the base experiment class, this is always the default experiment runner.
1657-
For experiments with multiple trial types, use the MultiTypeExperiment class.
1672+
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
16581673
"""
1659-
return trial._runner if trial._runner else self.runner
1674+
if not self.supports_trial_type(trial_type):
1675+
raise ValueError(f"Trial type `{trial_type}` is not supported.")
1676+
if (runner := self._trial_type_to_runner.get(trial_type)) is None:
1677+
return self.runner # return the default runner
1678+
return runner
16601679

16611680
def supports_trial_type(self, trial_type: str | None) -> bool:
16621681
"""Whether this experiment allows trials of the given type.

ax/core/multi_type_experiment.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,6 @@ def __init__(
7171
default_data_type: Enum representing the data type this experiment uses.
7272
"""
7373

74-
self._default_trial_type = default_trial_type
75-
76-
# Map from trial type to default runner of that type
77-
self._trial_type_to_runner: dict[str, Runner | None] = {
78-
default_trial_type: default_runner
79-
}
80-
8174
# Specifies which trial type each metric belongs to
8275
self._metric_to_trial_type: dict[str, str] = {}
8376

@@ -99,6 +92,8 @@ def __init__(
9992
properties=properties,
10093
default_data_type=default_data_type,
10194
tracking_metrics=tracking_metrics,
95+
runner=default_runner,
96+
default_trial_type=default_trial_type,
10297
)
10398

10499
def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
@@ -138,6 +133,7 @@ def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment
138133
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
139134

140135
self._trial_type_to_runner[trial_type] = runner
136+
self._runner = runner
141137
return self
142138

143139
def add_tracking_metric(
@@ -159,7 +155,7 @@ def add_tracking_metric(
159155
raise ValueError(f"`{trial_type}` is not a supported trial type.")
160156

161157
super().add_tracking_metric(metric)
162-
self._metric_to_trial_type[metric.name] = trial_type
158+
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
163159
if canonical_name is not None:
164160
self._metric_to_canonical_name[metric.name] = canonical_name
165161
return self
@@ -307,26 +303,6 @@ def default_trial_type(self) -> str | None:
307303
"""Default trial type assigned to trials in this experiment."""
308304
return self._default_trial_type
309305

310-
def runner_for_trial(self, trial: BaseTrial) -> Runner | None:
311-
"""The default runner to use for a given trial.
312-
313-
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
314-
"""
315-
return (
316-
trial._runner
317-
if trial._runner
318-
else self.runner_for_trial_type(trial_type=none_throws(trial.trial_type))
319-
)
320-
321-
def runner_for_trial_type(self, trial_type: str) -> Runner | None:
322-
"""The default runner to use for a given trial type.
323-
324-
Looks up the appropriate runner for this trial type in the trial_type_to_runner.
325-
"""
326-
if not self.supports_trial_type(trial_type):
327-
raise ValueError(f"Trial type `{trial_type}` is not supported.")
328-
return self._trial_type_to_runner[trial_type]
329-
330306
def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
331307
"""The default runner to use for a given trial type.
332308
@@ -347,11 +323,6 @@ def supports_trial_type(self, trial_type: str | None) -> bool:
347323
"""
348324
return trial_type in self._trial_type_to_runner.keys()
349325

350-
def reset_runners(self, runner: Runner) -> None:
351-
raise NotImplementedError(
352-
"MultiTypeExperiment does not support resetting all runners."
353-
)
354-
355326

356327
def filter_trials_by_type(
357328
trials: Sequence[BaseTrial], trial_type: str | None

ax/core/tests/test_batch_trial.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_UndefinedSetters(self) -> None:
9292
self.batch.status = TrialStatus.RUNNING
9393

9494
def test_BasicSetter(self) -> None:
95-
self.batch.runner = SyntheticRunner()
95+
self.experiment.runner = SyntheticRunner()
9696
self.assertIsNotNone(self.batch.runner)
9797

9898
def test_AddArm(self) -> None:
@@ -209,7 +209,7 @@ def test_BatchLifecycle(self) -> None:
209209
with patch.object(SyntheticRunner, "staging_required", staging_mock):
210210
mock_runner = SyntheticRunner()
211211
staging_mock.return_value = True
212-
self.batch.runner = mock_runner
212+
self.experiment.runner = mock_runner
213213
self.batch.run()
214214
self.assertEqual(self.batch.status, TrialStatus.STAGED)
215215
# Check that the trial statuses mapping on experiment has been updated.
@@ -231,7 +231,7 @@ def test_BatchLifecycle(self) -> None:
231231
with self.assertRaises(TrialMutationError):
232232
self.batch.add_arms_and_weights(arms=self.arms, weights=self.weights)
233233

234-
with self.assertRaises(TrialMutationError):
234+
with self.assertRaises(UnsupportedError):
235235
self.batch.runner = None
236236

237237
# Cannot run batch that was already run
@@ -317,7 +317,7 @@ def test_AbandonBatchTrial(self) -> None:
317317
self.assertEqual(self.batch.abandoned_reason, reason)
318318

319319
def test_FailedBatchTrial(self) -> None:
320-
self.batch.runner = SyntheticRunner()
320+
self.experiment.runner = SyntheticRunner()
321321
self.batch.run()
322322
self.batch.mark_failed()
323323

@@ -332,7 +332,7 @@ def test_StaleBatchTrial(self) -> None:
332332
self.assertIsNotNone(self.batch.time_completed)
333333

334334
def test_EarlyStoppedBatchTrial(self) -> None:
335-
self.batch.runner = SyntheticRunner()
335+
self.experiment.runner = SyntheticRunner()
336336
self.batch.run()
337337
self.batch.attach_batch_trial_data(
338338
raw_data={
@@ -431,7 +431,7 @@ def test_Runner(self) -> None:
431431
with self.assertRaises(ValueError):
432432
self.batch.mark_running()
433433

434-
self.batch.runner = SyntheticRunner()
434+
self.experiment.runner = SyntheticRunner()
435435
self.batch.run()
436436
self.assertEqual(self.batch.deployed_name, "test_0")
437437
self.assertNotEqual(len(self.batch.run_metadata.keys()), 0)

ax/core/tests/test_experiment.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,11 +798,10 @@ def test_ExperimentRunner(self) -> None:
798798
# pyre-fixme[6]: For 1st param expected `Optional[str]` but got `Dict[str,
799799
# bool]`.
800800
new_runner = SyntheticRunner(dummy_metadata=identifier)
801-
802-
self.experiment.reset_runners(new_runner)
803801
# Don't update trials that have been run.
804802
self.assertEqual(batch.runner, original_runner)
805803
# Update default runner
804+
self.experiment.runner = new_runner
806805
self.assertEqual(self.experiment.runner, new_runner)
807806
# Update candidate trial runners.
808807
self.assertEqual(self.experiment.trials[1].runner, new_runner)

ax/core/tests/test_multi_type_experiment.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,12 @@ def test_BadBehavior(self) -> None:
141141
batch = self.experiment.new_batch_trial()
142142
batch._trial_type = "type3" # Force override trial type
143143
with self.assertRaises(ValueError):
144-
self.experiment.runner_for_trial(batch)
144+
self.experiment.runner_for_trial_type(batch.trial_type)
145145

146146
# Try making trial with unsupported trial type
147147
with self.assertRaises(ValueError):
148148
self.experiment.new_batch_trial(trial_type="type3")
149149

150-
# Try resetting runners.
151-
with self.assertRaises(NotImplementedError):
152-
self.experiment.reset_runners(SyntheticRunner())
153-
154150
def test_setting_opt_config(self) -> None:
155151
self.assertDictEqual(
156152
self.experiment._metric_to_trial_type, {"m1": "type1", "m2": "type2"}

0 commit comments

Comments
 (0)