Skip to content

Commit e0ff0b4

Browse files
mpolson64facebook-github-bot
authored andcommitted
Implement should_stop_trial_early (#3139)
Summary: Pull Request resolved: #3139 Implement a lightweight wrapper around BaseEarlyStoppingStrategy.should_stop_trials_early that given a trial index simply tells whether the trial should be stopped or not. Note: this wrapper will be extremely helpful for us if we want to revamp the BaseEarlyStoppingStrategy interface (which could use some serious work) without disrupting users. Reviewed By: saitcakmak Differential Revision: D66669283 fbshipit-source-id: 9644b98181d7d055601a43f81ffbdcda5d92ba40
1 parent 46c7e47 commit e0ff0b4

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

ax/preview/api/client.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,22 @@ def should_stop_trial_early(self, trial_index: int) -> bool:
528528
Returns:
529529
Whether the trial should be stopped early.
530530
"""
531-
...
531+
if self._early_stopping_strategy is None:
532+
# In the future we may want to support inferring a default early stopping
533+
# strategy
534+
raise UnsupportedError(
535+
"Early stopping strategy not set. Please set an early stopping "
536+
"strategy before calling should_stop_trial_early."
537+
)
538+
539+
es_response = none_throws(
540+
self._early_stopping_strategy
541+
).should_stop_trials_early(
542+
trial_indices={trial_index}, experiment=self._none_throws_experiment()
543+
)
544+
545+
# TODO[mpolson64]: log the returned reason for stopping the trial
546+
return trial_index in es_response
532547

533548
# -------------------- Section 2.3 Marking trial status manually ----------------
534549
def mark_trial_failed(self, trial_index: int) -> None:

ax/preview/api/tests/test_client.py

+30
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ax.core.parameter_constraint import ParameterConstraint
2929
from ax.core.search_space import SearchSpace
3030
from ax.core.trial import Trial
31+
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
3132
from ax.exceptions.core import UnsupportedError
3233
from ax.preview.api.client import Client
3334
from ax.preview.api.configs import (
@@ -737,6 +738,35 @@ def test_mark_trial_early_stopped(self) -> None:
737738
),
738739
)
739740

741+
def test_should_stop_trial_early(self) -> None:
742+
client = Client()
743+
744+
client.configure_experiment(
745+
ExperimentConfig(
746+
parameters=[
747+
RangeParameterConfig(
748+
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
749+
),
750+
],
751+
name="foo",
752+
)
753+
)
754+
client.configure_optimization(objective="foo")
755+
756+
with self.assertRaisesRegex(
757+
UnsupportedError, "Early stopping strategy not set"
758+
):
759+
client.should_stop_trial_early(trial_index=0)
760+
761+
client.set_early_stopping_strategy(
762+
early_stopping_strategy=PercentileEarlyStoppingStrategy(
763+
metric_names=["foo"]
764+
)
765+
)
766+
767+
client.get_next_trials(maximum_trials=1)
768+
self.assertFalse(client.should_stop_trial_early(trial_index=0))
769+
740770

741771
class DummyRunner(IRunner):
742772
@override

0 commit comments

Comments
 (0)