-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit test for LCBench early stopping problem; add to registry; re…
…move metric_name (#3365) Summary: Pull Request resolved: #3365 - Add LCBench problems to the registry - Remove 'metric_name' argument to early stopping problem; the unit test revealed that this argument wasn't actually working, since the outcome names on the optimization config didn't match the outcome names on the test function. - Combined the baseline values for the transfer-learning and early-stopping problems since they are the same whenever they are both present. Reviewed By: ltiao Differential Revision: D69615360 fbshipit-source-id: 09106031d334bd899909208b6b9af395d70ff3cd
- Loading branch information
1 parent
fe7fd5a
commit 0759d38
Showing
8 changed files
with
166 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
ax/benchmark/tests/problems/surrogate/lcbench/test_early_stopping.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from unittest.mock import patch | ||
|
||
from ax.benchmark.problems.surrogate.lcbench.early_stopping import ( | ||
BASELINE_VALUES, | ||
get_lcbench_early_stopping_benchmark_problem, | ||
OPTIMAL_VALUES, | ||
) | ||
from ax.benchmark.problems.surrogate.lcbench.utils import DEFAULT_METRIC_NAME | ||
from ax.utils.common.testutils import TestCase | ||
from ax.utils.testing.benchmark_stubs import get_mock_lcbench_data | ||
|
||
|
||
class TestEarlyStoppingProblem(TestCase): | ||
def test_get_lcbench_early_stopping_problem(self) -> None: | ||
# Just test one problem for speed. We are mocking out the data load | ||
# anyway, so there is nothing to distinguish these problems from each | ||
# other | ||
|
||
observe_noise_sd = True | ||
num_trials = 4 | ||
noise_std = 1.0 | ||
seed = 27 | ||
dataset_name = "credit-g" | ||
|
||
early_stopping_path = get_lcbench_early_stopping_benchmark_problem.__module__ | ||
with patch( | ||
f"{early_stopping_path}.load_lcbench_data", | ||
return_value=get_mock_lcbench_data(), | ||
) as mock_load_lcbench_data, patch( | ||
# Fitting a surrogate won't work with this small synthetic data | ||
f"{early_stopping_path}._create_surrogate_regressor" | ||
) as mock_create_surrogate_regressor: | ||
problem = get_lcbench_early_stopping_benchmark_problem( | ||
dataset_name=dataset_name, | ||
observe_noise_sd=observe_noise_sd, | ||
num_trials=num_trials, | ||
constant_step_runtime=True, | ||
noise_std=noise_std, | ||
seed=seed, | ||
) | ||
|
||
mock_load_lcbench_data.assert_called_once() | ||
mock_load_lcbench_data_kwargs = mock_load_lcbench_data.call_args.kwargs | ||
self.assertEqual(mock_load_lcbench_data_kwargs["dataset_name"], dataset_name) | ||
create_surrogate_regressor_call_args = ( | ||
mock_create_surrogate_regressor.call_args_list | ||
) | ||
self.assertEqual(len(create_surrogate_regressor_call_args), 2) | ||
self.assertEqual(create_surrogate_regressor_call_args[0].kwargs["seed"], seed) | ||
self.assertEqual(problem.noise_std, noise_std) | ||
self.assertEqual( | ||
problem.optimization_config.objective.metric.name, DEFAULT_METRIC_NAME | ||
) | ||
self.assertIsNone(problem.step_runtime_function) | ||
self.assertEqual(problem.optimal_value, OPTIMAL_VALUES[dataset_name]) | ||
self.assertEqual(problem.baseline_value, BASELINE_VALUES[dataset_name]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.