Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions ax/adapter/tests/test_base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.utils.assorted import validate_input_scaling
from pandas.testing import assert_frame_equal
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws

ADAPTER__GEN_PATH: str = "ax.adapter.base.Adapter._gen"

Expand Down Expand Up @@ -908,8 +908,14 @@ def test_set_model_space(self) -> None:
.index.get_level_values("arm_name")
)
self.assertEqual(set(ood_arms), {"status_quo", "custom"})
self.assertEqual(m.model_space.parameters["x1"].lower, -5.0) # pyre-ignore[16]
self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16]
self.assertEqual(
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
-5.0,
)
self.assertEqual(
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
15.0,
)
self.assertEqual(len(m.model_space.parameter_constraints), 1)

# With expand model space, custom is not OOD, and model space is expanded
Expand All @@ -925,8 +931,14 @@ def test_set_model_space(self) -> None:
.index.get_level_values("arm_name")
)
self.assertEqual(set(ood_arms), {"status_quo"})
self.assertEqual(m.model_space.parameters["x1"].lower, -20.0)
self.assertEqual(m.model_space.parameters["x2"].upper, 18.0)
self.assertEqual(
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
-20.0,
)
self.assertEqual(
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
18.0,
)
self.assertEqual(m.model_space.parameter_constraints, [])

# With fill values, SQ is also in design, and x2 is further expanded
Expand All @@ -941,7 +953,10 @@ def test_set_model_space(self) -> None:
transform_configs={"FillMissingParameters": {"fill_values": sq_vals}},
)
self.assertEqual(sum(m.training_in_design), 7)
self.assertEqual(m.model_space.parameters["x2"].upper, 20)
self.assertEqual(
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
20,
)
self.assertEqual(m.model_space.parameter_constraints, [])

# Using parameter backfill values
Expand All @@ -955,7 +970,10 @@ def test_set_model_space(self) -> None:
search_space=ss,
)
self.assertEqual(sum(m.training_in_design), 7)
self.assertEqual(m.model_space.parameters["x2"].upper, 20)
self.assertEqual(
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
20,
)
self.assertEqual(m.model_space.parameter_constraints, [])

# Check log scale expansion with OOD trial having parameter value == 0
Expand Down Expand Up @@ -992,12 +1010,12 @@ def test_set_model_space(self) -> None:

# Assert that the expanded model space did not include 0.0
self.assertEqual(
m.model_space.parameters["x1"].lower,
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
0.0001,
)
# x2 model space should still be expanded
self.assertEqual(
m.model_space.parameters["x2"].upper,
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
2.0,
)

Expand Down
35 changes: 17 additions & 18 deletions ax/adapter/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections.abc import Iterable
from itertools import product
from typing import cast
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -70,6 +71,7 @@
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator
from pandas import DataFrame
from pyre_extensions import assert_is_instance

# Number of in-design points created by _create_adapter_with_out_of_design_points()
_OOD_ADAPTER_IN_DESIGN_COUNT = 3
Expand All @@ -78,9 +80,8 @@
class CrossValidationTest(TestCase):
def setUp(self) -> None:
super().setUp()
# pyre-ignore [9] Pyre is too picky with union types.
parameterizations: list[TParameterization] = [
{"x": x} for x in [2.0, 2.0, 3.0, 4.0]
cast(TParameterization, {"x": x}) for x in [2.0, 2.0, 3.0, 4.0]
]
means = [[2.0, 4.0], [3.0, 5.0], [7.0, 8.0], [9.0, 10.0]]
sems = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
Expand Down Expand Up @@ -894,29 +895,27 @@ def test_efficient_loo_cv_with_fully_bayesian_model(self) -> None:
experiment = get_branin_experiment(with_batch=True, with_completed_batch=True)

# Create adapter with SaasFullyBayesianSingleTaskGP
generator = BoTorchGenerator(
surrogate=Surrogate(
surrogate_spec=SurrogateSpec(
model_configs=[
ModelConfig(
botorch_model_class=SaasFullyBayesianSingleTaskGP,
)
],
),
)
)
adapter = TorchAdapter(
experiment=experiment,
generator=BoTorchGenerator(
surrogate=Surrogate(
surrogate_spec=SurrogateSpec(
model_configs=[
ModelConfig(
botorch_model_class=SaasFullyBayesianSingleTaskGP,
)
],
),
)
),
generator=generator,
transforms=[UnitX],
)

# We need to mock the MCMC fitting to avoid running actual NUTS sampling
# which is very slow. Instead, we'll inject mock MCMC samples.
surrogate = adapter.generator.surrogate # pyre-ignore[16]
model = surrogate.model

# Verify the model is a SaasFullyBayesianSingleTaskGP
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
surrogate = generator.surrogate
model = assert_is_instance(surrogate.model, SaasFullyBayesianSingleTaskGP)

# Get training data shape info
train_X = model.train_inputs[0]
Expand Down
9 changes: 3 additions & 6 deletions ax/adapter/tests/test_prediction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,11 @@ def test_predict_by_features(self) -> None:

@mock.patch("ax.adapter.random.RandomAdapter.predict")
@mock.patch("ax.adapter.random.RandomAdapter")
# pyre-fixme[3]: Return type must be annotated.
def test_predict_by_features_with_non_predicting_model(
self,
# pyre-fixme[2]: Parameter must be annotated.
adapter_mock,
# pyre-fixme[2]: Parameter must be annotated.
predict_mock,
):
adapter_mock: mock.MagicMock,
predict_mock: mock.MagicMock,
) -> None:
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
_attach_completed_trials(ax_client)

Expand Down
11 changes: 2 additions & 9 deletions ax/adapter/tests/test_random_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ax.exceptions.core import SearchSpaceExhausted
from ax.generators.random.base import RandomGenerator
from ax.generators.random.sobol import SobolGenerator
from ax.generators.types import TConfig
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_data,
Expand All @@ -45,7 +46,7 @@ def setUp(self) -> None:
]
self.search_space = SearchSpace(self.parameters, parameter_constraints)
self.experiment = Experiment(search_space=self.search_space)
self.model_gen_options = {"option": "yes"}
self.model_gen_options: TConfig = {"option": "yes"}

def test_fit(self) -> None:
adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator())
Expand Down Expand Up @@ -79,10 +80,6 @@ def test_gen_w_constraints(self) -> None:
pending_observations={},
fixed_features=ObservationFeatures({"z": 3.0}),
optimization_config=None,
# pyre-fixme[6]: For 6th param expected `Optional[Dict[str,
# Union[None, Dict[str, typing.Any], OptimizationConfig,
# AcquisitionFunction, float, int, str]]]` but got `Dict[str,
# str]`.
model_gen_options=self.model_gen_options,
)
gen_args = mock_gen.mock_calls[0][2]
Expand Down Expand Up @@ -129,10 +126,6 @@ def test_gen_simple(self) -> None:
pending_observations={},
fixed_features=ObservationFeatures({}),
optimization_config=None,
# pyre-fixme[6]: For 6th param expected `Optional[Dict[str,
# Union[None, Dict[str, typing.Any], OptimizationConfig,
# AcquisitionFunction, float, int, str]]]` but got `Dict[str,
# str]`.
model_gen_options=self.model_gen_options,
)
gen_args = mock_gen.mock_calls[0][2]
Expand Down
2 changes: 1 addition & 1 deletion ax/adapter/tests/test_torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def test_convert_experiment_data(self) -> None:
ordinal_features=[2],
discrete_choices={2: list(range(0, 11))},
task_features=[2] if use_task else [],
target_values={2: 0} if use_task else {}, # pyre-ignore
target_values={2: 0.0} if use_task else {},
)
converted_datasets, ordered_outcomes, _ = adapter._convert_experiment_data(
experiment_data=experiment_data,
Expand Down
3 changes: 1 addition & 2 deletions ax/adapter/tests/test_torch_moo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None:
)
for trial in exp.trials.values():
trial.mark_running(no_runner_required=True).mark_completed()
# pyre-fixme[16]: Optional type has no attribute `metrics`.
metrics_dict = exp.optimization_config.metrics
metrics_dict = none_throws(exp.optimization_config).metrics
# Objective thresholds and synthetic observations chosen to have closed-form
# hypervolumes to test.
objective_thresholds = [
Expand Down
3 changes: 0 additions & 3 deletions ax/adapter/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_extract_outcome_constraints(self) -> None:
OutcomeConstraint(metric=Metric("m1"), op=ComparisonOp.LEQ, bound=0)
]
res = extract_outcome_constraints(outcome_constraints, outcomes)
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
self.assertEqual(res[0].shape, (1, 3))
self.assertListEqual(list(res[0][0]), [1, 0, 0])
self.assertEqual(res[1][0][0], 0)
Expand Down Expand Up @@ -137,10 +136,8 @@ def test_extract_objective_thresholds(self) -> None:
outcomes=outcomes,
)
expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0])
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3]))
self.assertTrue(np.isnan(obj_t[-1]))
# pyre-fixme[16]: Optional type has no attribute `shape`.
self.assertEqual(obj_t.shape[0], 4)

# Returns NaN for objectives without a threshold.
Expand Down
5 changes: 3 additions & 2 deletions ax/adapter/transforms/tests/test_base_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ax.core.objective import Objective
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.types import TParameterization
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment

Expand Down Expand Up @@ -66,10 +67,10 @@ def test_TransformObservations(self) -> None:
means = np.array([3.0, 4.0])
metric_signatures = ["a", "b"]
covariance = np.array([[1.0, 2.0], [3.0, 4.0]])
parameters = {"x": 1.0, "y": "cat"}
parameters: TParameterization = {"x": 1.0, "y": "cat"}
arm_name = "armmy"
observation = Observation(
features=ObservationFeatures(parameters=parameters), # pyre-ignore
features=ObservationFeatures(parameters=parameters),
data=ObservationData(
metric_signatures=metric_signatures, means=means, covariance=covariance
),
Expand Down
11 changes: 5 additions & 6 deletions ax/adapter/transforms/tests/test_cast_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from pyre_extensions import none_throws


class CastTransformTest(TestCase):
Expand Down Expand Up @@ -179,8 +180,7 @@ def test_transform_observation_features_HSS(self) -> None:
self.assertIn(p_name, obsf.parameters)
# Check that full parameterization is recorded in metadata
self.assertEqual(
# pyre-fixme[16]: Optional type has no attribute `get`.
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
self.obs_feats_hss.parameters,
)

Expand All @@ -197,7 +197,7 @@ def test_transform_observation_features_HSS(self) -> None:
self.assertIn(p_name, obsf.parameters)
# Check that full parameterization is recorded in metadata
self.assertEqual(
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
self.obs_feats_hss.parameters,
)

Expand Down Expand Up @@ -245,8 +245,7 @@ def test_untransform_observation_features_HSS(self) -> None:
},
)
self.assertEqual(
# pyre-fixme[16]: Optional type has no attribute `get`.
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
self.obs_feats_hss.parameters,
)

Expand All @@ -264,7 +263,7 @@ def test_untransform_observation_features_HSS(self) -> None:
},
)
self.assertEqual(
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
self.obs_feats_hss_2.parameters,
)

Expand Down
5 changes: 3 additions & 2 deletions ax/adapter/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ def test_hss_dependents_are_preserved(self) -> None:
# x0 should be untouched because it's a fixed parameter.
self.assertIsInstance(hss.parameters["x0"], FixedParameter)
self.assertEqual(hss.parameters["x0"].parameter_type, ParameterType.BOOL)
# pyre-ignore[16] # Pyre doesn't understand fixed parameters have `.value`
self.assertEqual(hss.parameters["x0"].value, True)
self.assertEqual(
assert_is_instance(hss.parameters["x0"], FixedParameter).value, True
)
self.assertEqual(hss.parameters["x0"].dependents, {True: ["x1", "x2"]})

self.assertFalse(hss.parameters["x1"].is_hierarchical)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_num_choices(self) -> None:
"e", lower=3, upper=5, parameter_type=ParameterType.INT
),
}
search_space = SearchSpace(parameters=parameters.values()) # pyre-ignore[6]
search_space = SearchSpace(parameters=list(parameters.values()))

# Don't specify max_choices (should be set to inf)
t = IntRangeToChoice(search_space=search_space)
Expand Down
16 changes: 10 additions & 6 deletions ax/adapter/transforms/tests/test_logit_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment_with_observations
from pandas.testing import assert_frame_equal, assert_series_equal
from pyre_extensions import assert_is_instance
from scipy.special import expit, logit


Expand Down Expand Up @@ -107,16 +108,19 @@ def test_InvalidSettings(self) -> None:
def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
ss2 = self.t.transform_search_space(ss2)
# pyre-fixme[16]: `Parameter` has no attribute `lower`.
self.assertEqual(ss2.parameters["x"].lower, logit(0.9))
# pyre-fixme[16]: `Parameter` has no attribute `upper`.
self.assertEqual(ss2.parameters["x"].upper, logit(0.999))
self.assertEqual(
assert_is_instance(ss2.parameters["x"], RangeParameter).lower, logit(0.9)
)
self.assertEqual(
assert_is_instance(ss2.parameters["x"], RangeParameter).upper, logit(0.999)
)
t2 = Logit(search_space=self.search_space_with_target)
ss_target = deepcopy(self.search_space_with_target)
t2.transform_search_space(ss_target)
self.assertEqual(ss_target.parameters["x"].target_value, logit(0.123))
self.assertEqual(ss_target.parameters["x"].lower, logit(0.1))
self.assertEqual(ss_target.parameters["x"].upper, logit(0.3))
x_param = assert_is_instance(ss_target.parameters["x"], RangeParameter)
self.assertEqual(x_param.lower, logit(0.1))
self.assertEqual(x_param.upper, logit(0.3))

def test_transform_experiment_data(self) -> None:
parameterizations = [
Expand Down
7 changes: 4 additions & 3 deletions ax/adapter/transforms/tests/test_metrics_as_task_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.parameter import ChoiceParameter
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_search_space_for_range_values
from pyre_extensions import assert_is_instance


class MetricsAsTaskTransformTest(TestCase):
Expand Down Expand Up @@ -125,10 +126,10 @@ def test_TransformSearchSpace(self) -> None:
self.assertEqual(len(new_ss.parameters), 3)
new_param = new_ss.parameters["METRIC_TASK"]
self.assertIsInstance(new_param, ChoiceParameter)
new_param_choice = assert_is_instance(new_param, ChoiceParameter)
self.assertEqual(
# pyre-fixme[16]: `Parameter` has no attribute `values`.
new_param.values,
new_param_choice.values,
["TARGET", "metric1", "metric2"],
)
self.assertTrue(new_param.is_task) # pyre-ignore
self.assertTrue(new_param_choice.is_task)
self.assertEqual(new_param.target_value, "TARGET")
Loading