Skip to content

Commit

Permalink
Add unit test for LCBench early stopping problem; add to registry; re…
Browse files Browse the repository at this point in the history
…move metric_name (#3365)

Summary:
Pull Request resolved: #3365

- Add LCBench problems to the registry
- Remove 'metric_name' argument to early stopping problem; the unit test revealed that this argument wasn't actually working, since the outcome names on the optimization config didn't match the outcome names on the test function.
- Combined the baseline values for the transfer-learning and early-stopping problems since they are the same whenever they are both present.

Reviewed By: ltiao

Differential Revision: D69615360

fbshipit-source-id: 09106031d334bd899909208b6b9af395d70ff3cd
  • Loading branch information
esantorella authored and facebook-github-bot committed Feb 19, 2025
1 parent fe7fd5a commit 0759d38
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 101 deletions.
13 changes: 13 additions & 0 deletions ax/benchmark/problems/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
get_pytorch_cnn_torchvision_benchmark_problem,
)
from ax.benchmark.problems.runtime_funcs import int_from_params
from ax.benchmark.problems.surrogate.lcbench.early_stopping import (
get_lcbench_early_stopping_benchmark_problem,
)
from ax.benchmark.problems.surrogate.lcbench.transfer_learning import (
get_lcbench_benchmark_problem,
)
from ax.benchmark.problems.synthetic.bandit import get_bandit_problem
from ax.benchmark.problems.synthetic.discretized.mixed_integer import (
get_discrete_ackley,
Expand Down Expand Up @@ -147,6 +153,13 @@ class BenchmarkProblemRegistryEntry:
factory_fn=get_jenatton_benchmark_problem,
factory_kwargs={"num_trials": 50, "observe_noise_sd": False},
),
"LCBench:v1 Fashion-MNIST": BenchmarkProblemRegistryEntry(
get_lcbench_benchmark_problem, factory_kwargs={"dataset_name": "Fashion-MNIST"}
),
"LCBench Early Stopping Fashion-MNIST": BenchmarkProblemRegistryEntry(
get_lcbench_early_stopping_benchmark_problem,
factory_kwargs={"dataset_name": "Fashion-MNIST"},
),
"levy4": BenchmarkProblemRegistryEntry(
factory_fn=create_problem_from_botorch,
factory_kwargs={
Expand Down
8 changes: 8 additions & 0 deletions ax/benchmark/problems/surrogate/lcbench/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def url(self) -> str:

@dataclass(kw_only=True)
class LCBenchData:
"""
Args:
parameter_df: DataFrame with columns corresponding to the names of the
parameters in get_lcbench_parameter_names().
metric_series: Series of metric values with index names "trial" and "epoch".
timestamp_series: Series of timestamps with index name "trial".
"""

parameter_df: pd.DataFrame
metric_series: pd.Series
timestamp_series: pd.Series
Expand Down
49 changes: 6 additions & 43 deletions ax/benchmark/problems/surrogate/lcbench/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from ax.benchmark.problems.surrogate.lcbench.transfer_learning import DEFAULT_NUM_TRIALS
from ax.benchmark.problems.surrogate.lcbench.utils import (
BASELINE_VALUES,
DEFAULT_METRIC_NAME,
get_lcbench_log_scale_parameter_names,
get_lcbench_optimization_config,
Expand All @@ -45,43 +46,7 @@

TRegressorProtocol = TypeVar("TRegressorProtocol", bound="RegressorProtocol")

BASELINE_VALUES: dict[str, float] = {
"APSFailure": 97.75948131763847,
"Amazon_employee_access": 93.39364177908142,
"Australian": 88.1445880383116,
"Fashion-MNIST": 84.75904272864778,
"KDDCup09_appetency": 96.13544312868322,
"MiniBooNE": 85.8639428612948,
"adult": 79.50334987749676,
"airlines": 58.96099030718572,
"albert": 63.885932360810884,
"bank-marketing": 83.72755317459641,
"blood-transfusion-service-center": 62.651717620524835,
"car": 78.59464531457958,
"christine": 72.22719165860138,
"cnae-9": 92.24923138962973,
"connect-4": 63.808749677494774,
"covertype": 61.61393200315512,
"credit-g": 70.45312807563056,
"dionis": 53.71071232033245,
"fabert": 64.44304132875557,
"helena": 18.239085505279544,
"higgs": 64.74999655474926,
"jannis": 57.82155396833136,
"jasmine": 80.48475426337272,
"jungle_chess_2pcs_raw_endgame_complete": 65.58537332961572,
"kc1": 77.28692486000287,
"kr-vs-kp": 93.63368446446995,
"mfeat-factors": 94.72758417873838,
"nomao": 93.73968374826451,
"numerai28.6": 51.60281273196557,
"phoneme": 75.20979771001986,
"segment": 78.81992685291081,
"shuttle": 96.45744339531132,
"sylvine": 91.15923021902736,
"vehicle": 67.40729695042013,
"volkert": 49.204981948803855,
}

OPTIMAL_VALUES: dict[str, float] = {
"APSFailure": 98.97643280029295,
"Amazon_employee_access": 94.1208953857422,
Expand Down Expand Up @@ -256,18 +221,17 @@ def __post_init__(

def evaluate_true(self, params: Mapping[str, TParamValue]) -> torch.Tensor:
X = pd.DataFrame.from_records(data=[params])
Y = self.metric_surrogate.predict(X) # shape: (1, 50)
Y = self.metric_surrogate.predict(X=X) # shape: (1, 50)
return torch.from_numpy(Y)

def step_runtime(self, params: Mapping[str, TParamValue]) -> float:
X = pd.DataFrame.from_records(data=[params])
Y = self.runtime_surrogate.predict(X) # shape: (1,)
Y = self.runtime_surrogate.predict(X=X) # shape: (1,)
return Y.item()


def get_lcbench_early_stopping_benchmark_problem(
dataset_name: str,
metric_name: str = DEFAULT_METRIC_NAME,
num_trials: int = DEFAULT_NUM_TRIALS,
constant_step_runtime: bool = False,
noise_std: Mapping[str, float] | float = 0.0,
Expand All @@ -279,7 +243,6 @@ def get_lcbench_early_stopping_benchmark_problem(
Args:
dataset_name: Must be one of the keys of `DEFAULT_AND_OPTIMAL_VALUES`, which
correspond to the names of the datasets available in LCBench.
metric_name: The name of the metric to use for the objective.
num_trials: The number of optimization trials to run.
constant_step_runtime: Determines if the step runtime is fixed or varies
based on the hyperparameters.
Expand All @@ -296,14 +259,14 @@ def get_lcbench_early_stopping_benchmark_problem(
if dataset_name not in DATASET_NAMES:
raise UserInputError(f"`dataset_name` must be one of {sorted(DATASET_NAMES)}")

name = f"LCBench_Surrogate_{dataset_name}_{metric_name}:v1"
name = f"LCBench_Surrogate_{dataset_name}:v1"

optimal_value = OPTIMAL_VALUES[dataset_name]
baseline_value = BASELINE_VALUES[dataset_name]

search_space: SearchSpace = get_lcbench_search_space()
optimization_config: OptimizationConfig = get_lcbench_optimization_config(
metric_name=metric_name,
metric_name=DEFAULT_METRIC_NAME,
observe_noise_sd=observe_noise_sd,
use_map_metric=True,
)
Expand Down
59 changes: 2 additions & 57 deletions ax/benchmark/problems/surrogate/lcbench/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
from ax.benchmark.problems.surrogate.lcbench.utils import (
BASELINE_VALUES,
DEFAULT_METRIC_NAME,
get_lcbench_optimization_config,
get_lcbench_search_space,
)
from ax.core.experiment import Experiment
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans
from ax.modelbridge.torch import TorchAdapter
Expand All @@ -35,30 +33,7 @@

DEFAULT_NUM_TRIALS: int = 30

BASELINE_VALUES: dict[str, float] = {
"KDDCup09_appetency": 94.84762378096477,
"APSFailure": 97.75754021610224,
"albert": 63.893807756587876,
"Amazon_employee_access": 93.92434556024065,
"Australian": 89.35657945184583,
"Fashion-MNIST": 84.94202558279305,
"car": 80.47958436427733,
"christine": 72.27323565977512,
"cnae-9": 94.15832149950144,
"covertype": 61.552294168420595,
"dionis": 54.99212355534204,
"fabert": 64.88207128531921,
"helena": 19.156010689783603,
"higgs": 64.84690723875762,
"jannis": 57.58628096200955,
"jasmine": 80.6321652907534,
"kr-vs-kp": 94.53560263952683,
"mfeat-factors": 95.58423367904923,
"nomao": 93.51402242799601,
"shuttle": 96.43481523407816,
"sylvine": 91.91719206036713,
"volkert": 49.50686237250762,
}

DEFAULT_AND_OPTIMAL_VALUES: dict[str, tuple[float, float]] = {
"KDDCup09_appetency": (87.14437173839048, 100.41903197808242),
"APSFailure": (97.3412499690734, 98.38099041845653),
Expand All @@ -85,36 +60,6 @@
}


def get_lcbench_experiment(
metric_name: str = DEFAULT_METRIC_NAME,
observe_noise_stds: bool = False,
) -> Experiment:
"""Construct an experiment with the LCBench search space and optimization config.
Used in N5808878 to fit the initial surrogate, and may be useful for the setup
of transfer learning experiments.
Args:
observe_noise_stds: Whether or not the magnitude of the observation noise
is known.
metric_name: The name of the metric to use for the objective.
Returns:
An experiment with the LCBench search space and optimization config.
"""

search_space: SearchSpace = get_lcbench_search_space()
optimization_config: OptimizationConfig = get_lcbench_optimization_config(
metric_name=metric_name,
observe_noise_sd=observe_noise_stds,
use_map_metric=False,
)

experiment = Experiment(
search_space=search_space, optimization_config=optimization_config
)
return experiment


def get_lcbench_surrogate() -> Surrogate:
"""Construct a surrogate used to fit the LCBench data.
Expand Down
38 changes: 38 additions & 0 deletions ax/benchmark/problems/surrogate/lcbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,44 @@

DEFAULT_METRIC_NAME: str = "Train/val_accuracy"

BASELINE_VALUES: dict[str, float] = {
"APSFailure": 97.75948131763847,
"Amazon_employee_access": 93.39364177908142,
"Australian": 88.1445880383116,
"Fashion-MNIST": 84.75904272864778,
"KDDCup09_appetency": 96.13544312868322,
"MiniBooNE": 85.8639428612948,
"adult": 79.50334987749676,
"airlines": 58.96099030718572,
"albert": 63.885932360810884,
"bank-marketing": 83.72755317459641,
"blood-transfusion-service-center": 62.651717620524835,
"car": 78.59464531457958,
"christine": 72.22719165860138,
"cnae-9": 92.24923138962973,
"connect-4": 63.808749677494774,
"covertype": 61.61393200315512,
"credit-g": 70.45312807563056,
"dionis": 53.71071232033245,
"fabert": 64.44304132875557,
"helena": 18.239085505279544,
"higgs": 64.74999655474926,
"jannis": 57.82155396833136,
"jasmine": 80.48475426337272,
"jungle_chess_2pcs_raw_endgame_complete": 65.58537332961572,
"kc1": 77.28692486000287,
"kr-vs-kp": 93.63368446446995,
"mfeat-factors": 94.72758417873838,
"nomao": 93.73968374826451,
"numerai28.6": 51.60281273196557,
"phoneme": 75.20979771001986,
"segment": 78.81992685291081,
"shuttle": 96.45744339531132,
"sylvine": 91.15923021902736,
"vehicle": 67.40729695042013,
"volkert": 49.204981948803855,
}


def get_lcbench_search_space() -> SearchSpace:
"""Construct the LCBench search space."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest.mock import patch

from ax.benchmark.problems.surrogate.lcbench.early_stopping import (
BASELINE_VALUES,
get_lcbench_early_stopping_benchmark_problem,
OPTIMAL_VALUES,
)
from ax.benchmark.problems.surrogate.lcbench.utils import DEFAULT_METRIC_NAME
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_mock_lcbench_data


class TestEarlyStoppingProblem(TestCase):
def test_get_lcbench_early_stopping_problem(self) -> None:
# Just test one problem for speed. We are mocking out the data load
# anyway, so there is nothing to distinguish these problems from each
# other

observe_noise_sd = True
num_trials = 4
noise_std = 1.0
seed = 27
dataset_name = "credit-g"

early_stopping_path = get_lcbench_early_stopping_benchmark_problem.__module__
with patch(
f"{early_stopping_path}.load_lcbench_data",
return_value=get_mock_lcbench_data(),
) as mock_load_lcbench_data, patch(
# Fitting a surrogate won't work with this small synthetic data
f"{early_stopping_path}._create_surrogate_regressor"
) as mock_create_surrogate_regressor:
problem = get_lcbench_early_stopping_benchmark_problem(
dataset_name=dataset_name,
observe_noise_sd=observe_noise_sd,
num_trials=num_trials,
constant_step_runtime=True,
noise_std=noise_std,
seed=seed,
)

mock_load_lcbench_data.assert_called_once()
mock_load_lcbench_data_kwargs = mock_load_lcbench_data.call_args.kwargs
self.assertEqual(mock_load_lcbench_data_kwargs["dataset_name"], dataset_name)
create_surrogate_regressor_call_args = (
mock_create_surrogate_regressor.call_args_list
)
self.assertEqual(len(create_surrogate_regressor_call_args), 2)
self.assertEqual(create_surrogate_regressor_call_args[0].kwargs["seed"], seed)
self.assertEqual(problem.noise_std, noise_std)
self.assertEqual(
problem.optimization_config.objective.metric.name, DEFAULT_METRIC_NAME
)
self.assertIsNone(problem.step_runtime_function)
self.assertEqual(problem.optimal_value, OPTIMAL_VALUES[dataset_name])
self.assertEqual(problem.baseline_value, BASELINE_VALUES[dataset_name])
11 changes: 10 additions & 1 deletion ax/benchmark/tests/problems/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
# pyre-strict


from unittest.mock import patch

from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.problems.registry import BENCHMARK_PROBLEM_REGISTRY, get_problem
from ax.benchmark.problems.runtime_funcs import int_from_params
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_mock_lcbench_data


class TestProblems(TestCase):
Expand All @@ -19,7 +22,13 @@ def test_load_problems(self) -> None:
if "MNIST" in name:
continue # Skip these as they cause the test to take a long time

problem = get_problem(problem_key=name)
# Avoid downloading data from the internet
with patch(
"ax.benchmark.problems.surrogate."
"lcbench.early_stopping.load_lcbench_data",
return_value=get_mock_lcbench_data(),
):
problem = get_problem(problem_key=name)
self.assertIsInstance(problem, BenchmarkProblem, msg=name)

def test_name(self) -> None:
Expand Down
Loading

0 comments on commit 0759d38

Please sign in to comment.