Skip to content

Commit

Permalink
Implement should_stop_trial_early (facebook#3139)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3139

Implement a lightweight wrapper around BaseEarlyStoppingStrategy.should_stop_trials_early that given a trial index simply tells whether the trial should be stopped or not.

Note: this wrapper will be extremely helpful for us if we want to revamp the BaseEarlyStoppingStrategy interface (which could use some serious work) without disrupting users.

Reviewed By: saitcakmak

Differential Revision: D66669283

fbshipit-source-id: 9644b98181d7d055601a43f81ffbdcda5d92ba40
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Dec 5, 2024
1 parent 46c7e47 commit e0ff0b4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
17 changes: 16 additions & 1 deletion ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,22 @@ def should_stop_trial_early(self, trial_index: int) -> bool:
Returns:
Whether the trial should be stopped early.
"""
...
if self._early_stopping_strategy is None:
# In the future we may want to support inferring a default early stopping
# strategy
raise UnsupportedError(
"Early stopping strategy not set. Please set an early stopping "
"strategy before calling should_stop_trial_early."
)

es_response = none_throws(
self._early_stopping_strategy
).should_stop_trials_early(
trial_indices={trial_index}, experiment=self._none_throws_experiment()
)

# TODO[mpolson64]: log the returned reason for stopping the trial
return trial_index in es_response

# -------------------- Section 2.3 Marking trial status manually ----------------
def mark_trial_failed(self, trial_index: int) -> None:
Expand Down
30 changes: 30 additions & 0 deletions ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.core.trial import Trial
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.preview.api.client import Client
from ax.preview.api.configs import (
Expand Down Expand Up @@ -737,6 +738,35 @@ def test_mark_trial_early_stopped(self) -> None:
),
)

def test_should_stop_trial_early(self) -> None:
client = Client()

client.configure_experiment(
ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")

with self.assertRaisesRegex(
UnsupportedError, "Early stopping strategy not set"
):
client.should_stop_trial_early(trial_index=0)

client.set_early_stopping_strategy(
early_stopping_strategy=PercentileEarlyStoppingStrategy(
metric_names=["foo"]
)
)

client.get_next_trials(maximum_trials=1)
self.assertFalse(client.should_stop_trial_early(trial_index=0))


class DummyRunner(IRunner):
@override
Expand Down

0 comments on commit e0ff0b4

Please sign in to comment.