From e0ff0b47c58c4da9e21f50881bd3b1401a4d7ec6 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Thu, 5 Dec 2024 10:42:56 -0800 Subject: [PATCH] Implement should_stop_trial_early (#3139) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3139 Implement a lightweight wrapper around BaseEarlyStoppingStrategy.should_stop_trials_early that given a trial index simply tells whether the trial should be stopped or not. Note: this wrapper will be extremely helpful for us if we want to revamp the BaseEarlyStoppingStrategy interface (which could use some serious work) without disrupting users. Reviewed By: saitcakmak Differential Revision: D66669283 fbshipit-source-id: 9644b98181d7d055601a43f81ffbdcda5d92ba40 --- ax/preview/api/client.py | 17 +++++++++++++++- ax/preview/api/tests/test_client.py | 30 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 4df31bafb3b..8808c13291b 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -528,7 +528,22 @@ def should_stop_trial_early(self, trial_index: int) -> bool: Returns: Whether the trial should be stopped early. """ - ... + if self._early_stopping_strategy is None: + # In the future we may want to support inferring a default early stopping + # strategy + raise UnsupportedError( + "Early stopping strategy not set. Please set an early stopping " + "strategy before calling should_stop_trial_early." + ) + + es_response = none_throws( + self._early_stopping_strategy + ).should_stop_trials_early( + trial_indices={trial_index}, experiment=self._none_throws_experiment() + ) + + # TODO[mpolson64]: log the returned reason for stopping the trial + return trial_index in es_response # -------------------- Section 2.3 Marking trial status manually ---------------- def mark_trial_failed(self, trial_index: int) -> None: diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 2f07cd5f87c..12166796cdc 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -28,6 +28,7 @@ from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import SearchSpace from ax.core.trial import Trial +from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy from ax.exceptions.core import UnsupportedError from ax.preview.api.client import Client from ax.preview.api.configs import ( @@ -737,6 +738,35 @@ def test_mark_trial_early_stopped(self) -> None: ), ) + def test_should_stop_trial_early(self) -> None: + client = Client() + + client.configure_experiment( + ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + ], + name="foo", + ) + ) + client.configure_optimization(objective="foo") + + with self.assertRaisesRegex( + UnsupportedError, "Early stopping strategy not set" + ): + client.should_stop_trial_early(trial_index=0) + + client.set_early_stopping_strategy( + early_stopping_strategy=PercentileEarlyStoppingStrategy( + metric_names=["foo"] + ) + ) + + client.get_next_trials(maximum_trials=1) + self.assertFalse(client.should_stop_trial_early(trial_index=0)) + class DummyRunner(IRunner): @override