Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify function for generating runtimes for benchmarking #3118

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ax/benchmark/problems/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.benchmark.problems.hpo.torchvision import (
get_pytorch_cnn_torchvision_benchmark_problem,
)
from ax.benchmark.problems.runtime_funcs import async_runtime_func_from_pi
from ax.benchmark.problems.runtime_funcs import int_from_trial
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
from botorch.test_functions import synthetic
from botorch.test_functions.multi_objective import BraninCurrin
Expand Down Expand Up @@ -45,7 +45,7 @@ class BenchmarkProblemRegistryEntry:
"num_trials": 40,
"noise_std": 1.0,
"observe_noise_sd": False,
"trial_runtime_func": async_runtime_func_from_pi,
"trial_runtime_func": int_from_trial,
"name": "ackley4_async_noisy",
},
),
Expand Down
29 changes: 24 additions & 5 deletions ax/benchmark/problems/runtime_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,30 @@

# pyre-strict

import random
from collections.abc import Mapping

from ax.core.trial import Trial
from ax.core.types import TParamValue
from pyre_extensions import none_throws


def int_from_params(
params: Mapping[str, TParamValue], n_possibilities: int = 10
) -> int:
"""
Get a random int between 0 and n_possibilities - 1, using parameters for the
random seed.
"""
seed = str(tuple(sorted(params.items())))
return random.Random(seed).randrange(n_possibilities)


def async_runtime_func_from_pi(trial: Trial) -> int:
# First 49 digits of pi, not including the decimal
pi_digits_str = "3141592653589793115997963468544185161590576171875"
idx = trial.index % len(pi_digits_str)
return int(pi_digits_str[idx])
def int_from_trial(trial: Trial, n_possibilities: int = 10) -> int:
"""
Get a random int between 0 and n_possibilities - 1, using the parameters of
the trial's first arm for the random seed.
"""
return int_from_params(
params=none_throws(trial.arms)[0].parameters, n_possibilities=n_possibilities
)
15 changes: 15 additions & 0 deletions ax/benchmark/tests/problems/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

# pyre-strict

from unittest.mock import MagicMock

from ax.benchmark.problems.registry import BENCHMARK_PROBLEM_REGISTRY, get_problem
from ax.benchmark.problems.runtime_funcs import int_from_params, int_from_trial
from ax.core.arm import Arm
from ax.core.trial import Trial
from ax.utils.common.testutils import TestCase


Expand Down Expand Up @@ -57,3 +62,13 @@ def test_registry_kwargs_not_mutated(self) -> None:
)
problem = get_problem(problem_key="jenatton")
self.assertEqual(problem.num_trials, 50)

def test_runtime_funcs(self) -> None:
parameters = {"x0": 0.5, "x1": -3, "x2": "-4", "x3": False, "x4": None}
result = int_from_params(params=parameters)
expected = 3
self.assertEqual(result, expected)
arm = Arm(name="0_0", parameters=parameters)
trial = MagicMock(spec=Trial)
trial.arms = [arm]
self.assertEqual(int_from_trial(trial=trial), expected)