Skip to content

Commit 38bf30f

Browse files
saitcakmak and meta-codesync[bot]
authored and committed
Remove pyre-fixme/pyre-ignore from ax/core, ax/adapter, ax/generators test files (#4989)
Summary: Pull Request resolved: #4989 Remove pyre-fixme and pyre-ignore type suppression comments from test files in ax/core/tests, ax/adapter/tests, ax/adapter/transforms/tests, and source file ax/adapter/transforms/one_hot.py. Uses proper type narrowing via none_throws, assert_is_instance, cast, and explicit type annotations instead of suppression comments. Key changes: - Replace `# pyre-ignore[16]` on `Parameter` attribute access with `assert_is_instance(..., RangeParameter)` / `ChoiceParameter` / `FixedParameter` - Replace `# pyre-fixme[16]` on Optional access with `none_throws(...)` - Add explicit type annotations (`TParameterization`, `TConfig`, `list[float]`, `dict[str, float | int]`) to fix type inference issues - Replace `**attrs` dict unpacking with explicit kwargs to eliminate union-type pyre errors in test_observation.py - Fix `all()` generator expression scoping bug in test_batch_trial.py (missing parentheses caused pyre-fixme[6]) - Remove unnecessary `return` statements inside `assertRaises` blocks - Add missing return type and parameter annotations on mock-decorated test methods - Refactor BoTorchGenerator construction in test_cross_validation.py to avoid pyre-ignore on `adapter.generator.surrogate` access Reviewed By: dme65 Differential Revision: D95273495 fbshipit-source-id: 5e4b0d1db817b1f95a4d691b7598358aed028da0
1 parent 6560861 commit 38bf30f

24 files changed

+224
-239
lines changed

ax/adapter/tests/test_base_adapter.py

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@
7676
from botorch.exceptions.warnings import InputDataWarning
7777
from botorch.models.utils.assorted import validate_input_scaling
7878
from pandas.testing import assert_frame_equal
79-
from pyre_extensions import none_throws
79+
from pyre_extensions import assert_is_instance, none_throws
8080

8181
ADAPTER__GEN_PATH: str = "ax.adapter.base.Adapter._gen"
8282

@@ -908,8 +908,14 @@ def test_set_model_space(self) -> None:
908908
.index.get_level_values("arm_name")
909909
)
910910
self.assertEqual(set(ood_arms), {"status_quo", "custom"})
911-
self.assertEqual(m.model_space.parameters["x1"].lower, -5.0) # pyre-ignore[16]
912-
self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16]
911+
self.assertEqual(
912+
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
913+
-5.0,
914+
)
915+
self.assertEqual(
916+
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
917+
15.0,
918+
)
913919
self.assertEqual(len(m.model_space.parameter_constraints), 1)
914920

915921
# With expand model space, custom is not OOD, and model space is expanded
@@ -925,8 +931,14 @@ def test_set_model_space(self) -> None:
925931
.index.get_level_values("arm_name")
926932
)
927933
self.assertEqual(set(ood_arms), {"status_quo"})
928-
self.assertEqual(m.model_space.parameters["x1"].lower, -20.0)
929-
self.assertEqual(m.model_space.parameters["x2"].upper, 18.0)
934+
self.assertEqual(
935+
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
936+
-20.0,
937+
)
938+
self.assertEqual(
939+
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
940+
18.0,
941+
)
930942
self.assertEqual(m.model_space.parameter_constraints, [])
931943

932944
# With fill values, SQ is also in design, and x2 is further expanded
@@ -941,7 +953,10 @@ def test_set_model_space(self) -> None:
941953
transform_configs={"FillMissingParameters": {"fill_values": sq_vals}},
942954
)
943955
self.assertEqual(sum(m.training_in_design), 7)
944-
self.assertEqual(m.model_space.parameters["x2"].upper, 20)
956+
self.assertEqual(
957+
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
958+
20,
959+
)
945960
self.assertEqual(m.model_space.parameter_constraints, [])
946961

947962
# Using parameter backfill values
@@ -955,7 +970,10 @@ def test_set_model_space(self) -> None:
955970
search_space=ss,
956971
)
957972
self.assertEqual(sum(m.training_in_design), 7)
958-
self.assertEqual(m.model_space.parameters["x2"].upper, 20)
973+
self.assertEqual(
974+
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
975+
20,
976+
)
959977
self.assertEqual(m.model_space.parameter_constraints, [])
960978

961979
# Check log scale expansion with OOD trial having parameter value == 0
@@ -992,12 +1010,12 @@ def test_set_model_space(self) -> None:
9921010

9931011
# Assert that the expanded model space did not include 0.0
9941012
self.assertEqual(
995-
m.model_space.parameters["x1"].lower,
1013+
assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower,
9961014
0.0001,
9971015
)
9981016
# x2 model space should still be expanded
9991017
self.assertEqual(
1000-
m.model_space.parameters["x2"].upper,
1018+
assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper,
10011019
2.0,
10021020
)
10031021

ax/adapter/tests/test_cross_validation.py

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
from collections.abc import Iterable
1111
from itertools import product
12+
from typing import cast
1213
from unittest import mock
1314

1415
import numpy as np
@@ -70,6 +71,7 @@
7071
from gpytorch.distributions import MultivariateNormal
7172
from linear_operator.operators import DiagLinearOperator
7273
from pandas import DataFrame
74+
from pyre_extensions import assert_is_instance
7375

7476
# Number of in-design points created by _create_adapter_with_out_of_design_points()
7577
_OOD_ADAPTER_IN_DESIGN_COUNT = 3
@@ -78,9 +80,8 @@
7880
class CrossValidationTest(TestCase):
7981
def setUp(self) -> None:
8082
super().setUp()
81-
# pyre-ignore [9] Pyre is too picky with union types.
8283
parameterizations: list[TParameterization] = [
83-
{"x": x} for x in [2.0, 2.0, 3.0, 4.0]
84+
cast(TParameterization, {"x": x}) for x in [2.0, 2.0, 3.0, 4.0]
8485
]
8586
means = [[2.0, 4.0], [3.0, 5.0], [7.0, 8.0], [9.0, 10.0]]
8687
sems = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
@@ -894,29 +895,27 @@ def test_efficient_loo_cv_with_fully_bayesian_model(self) -> None:
894895
experiment = get_branin_experiment(with_batch=True, with_completed_batch=True)
895896

896897
# Create adapter with SaasFullyBayesianSingleTaskGP
898+
generator = BoTorchGenerator(
899+
surrogate=Surrogate(
900+
surrogate_spec=SurrogateSpec(
901+
model_configs=[
902+
ModelConfig(
903+
botorch_model_class=SaasFullyBayesianSingleTaskGP,
904+
)
905+
],
906+
),
907+
)
908+
)
897909
adapter = TorchAdapter(
898910
experiment=experiment,
899-
generator=BoTorchGenerator(
900-
surrogate=Surrogate(
901-
surrogate_spec=SurrogateSpec(
902-
model_configs=[
903-
ModelConfig(
904-
botorch_model_class=SaasFullyBayesianSingleTaskGP,
905-
)
906-
],
907-
),
908-
)
909-
),
911+
generator=generator,
910912
transforms=[UnitX],
911913
)
912914

913915
# We need to mock the MCMC fitting to avoid running actual NUTS sampling
914916
# which is very slow. Instead, we'll inject mock MCMC samples.
915-
surrogate = adapter.generator.surrogate # pyre-ignore[16]
916-
model = surrogate.model
917-
918-
# Verify the model is a SaasFullyBayesianSingleTaskGP
919-
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
917+
surrogate = generator.surrogate
918+
model = assert_is_instance(surrogate.model, SaasFullyBayesianSingleTaskGP)
920919

921920
# Get training data shape info
922921
train_X = model.train_inputs[0]

ax/adapter/tests/test_prediction_utils.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,14 +83,11 @@ def test_predict_by_features(self) -> None:
8383

8484
@mock.patch("ax.adapter.random.RandomAdapter.predict")
8585
@mock.patch("ax.adapter.random.RandomAdapter")
86-
# pyre-fixme[3]: Return type must be annotated.
8786
def test_predict_by_features_with_non_predicting_model(
8887
self,
89-
# pyre-fixme[2]: Parameter must be annotated.
90-
adapter_mock,
91-
# pyre-fixme[2]: Parameter must be annotated.
92-
predict_mock,
93-
):
88+
adapter_mock: mock.MagicMock,
89+
predict_mock: mock.MagicMock,
90+
) -> None:
9491
ax_client = _set_up_client_for_get_model_predictions_no_next_trial()
9592
_attach_completed_trials(ax_client)
9693

ax/adapter/tests/test_random_adapter.py

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from ax.exceptions.core import SearchSpaceExhausted
2424
from ax.generators.random.base import RandomGenerator
2525
from ax.generators.random.sobol import SobolGenerator
26+
from ax.generators.types import TConfig
2627
from ax.utils.common.testutils import TestCase
2728
from ax.utils.testing.core_stubs import (
2829
get_data,
@@ -45,7 +46,7 @@ def setUp(self) -> None:
4546
]
4647
self.search_space = SearchSpace(self.parameters, parameter_constraints)
4748
self.experiment = Experiment(search_space=self.search_space)
48-
self.model_gen_options = {"option": "yes"}
49+
self.model_gen_options: TConfig = {"option": "yes"}
4950

5051
def test_fit(self) -> None:
5152
adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator())
@@ -79,10 +80,6 @@ def test_gen_w_constraints(self) -> None:
7980
pending_observations={},
8081
fixed_features=ObservationFeatures({"z": 3.0}),
8182
optimization_config=None,
82-
# pyre-fixme[6]: For 6th param expected `Optional[Dict[str,
83-
# Union[None, Dict[str, typing.Any], OptimizationConfig,
84-
# AcquisitionFunction, float, int, str]]]` but got `Dict[str,
85-
# str]`.
8683
model_gen_options=self.model_gen_options,
8784
)
8885
gen_args = mock_gen.mock_calls[0][2]
@@ -129,10 +126,6 @@ def test_gen_simple(self) -> None:
129126
pending_observations={},
130127
fixed_features=ObservationFeatures({}),
131128
optimization_config=None,
132-
# pyre-fixme[6]: For 6th param expected `Optional[Dict[str,
133-
# Union[None, Dict[str, typing.Any], OptimizationConfig,
134-
# AcquisitionFunction, float, int, str]]]` but got `Dict[str,
135-
# str]`.
136129
model_gen_options=self.model_gen_options,
137130
)
138131
gen_args = mock_gen.mock_calls[0][2]

ax/adapter/tests/test_torch_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -632,7 +632,7 @@ def test_convert_experiment_data(self) -> None:
632632
ordinal_features=[2],
633633
discrete_choices={2: list(range(0, 11))},
634634
task_features=[2] if use_task else [],
635-
target_values={2: 0} if use_task else {}, # pyre-ignore
635+
target_values={2: 0.0} if use_task else {},
636636
)
637637
converted_datasets, ordered_outcomes, _ = adapter._convert_experiment_data(
638638
experiment_data=experiment_data,

ax/adapter/tests/test_torch_moo_adapter.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -322,8 +322,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None:
322322
)
323323
for trial in exp.trials.values():
324324
trial.mark_running(no_runner_required=True).mark_completed()
325-
# pyre-fixme[16]: Optional type has no attribute `metrics`.
326-
metrics_dict = exp.optimization_config.metrics
325+
metrics_dict = none_throws(exp.optimization_config).metrics
327326
# Objective thresholds and synthetic observations chosen to have closed-form
328327
# hypervolumes to test.
329328
objective_thresholds = [

ax/adapter/tests/test_utils.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,6 @@ def test_extract_outcome_constraints(self) -> None:
8585
OutcomeConstraint(metric=Metric("m1"), op=ComparisonOp.LEQ, bound=0)
8686
]
8787
res = extract_outcome_constraints(outcome_constraints, outcomes)
88-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
8988
self.assertEqual(res[0].shape, (1, 3))
9089
self.assertListEqual(list(res[0][0]), [1, 0, 0])
9190
self.assertEqual(res[1][0][0], 0)
@@ -137,10 +136,8 @@ def test_extract_objective_thresholds(self) -> None:
137136
outcomes=outcomes,
138137
)
139138
expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0])
140-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
141139
self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3]))
142140
self.assertTrue(np.isnan(obj_t[-1]))
143-
# pyre-fixme[16]: Optional type has no attribute `shape`.
144141
self.assertEqual(obj_t.shape[0], 4)
145142

146143
# Returns NaN for objectives without a threshold.

ax/adapter/transforms/tests/test_base_transform.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from ax.core.objective import Objective
1919
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2020
from ax.core.optimization_config import OptimizationConfig
21+
from ax.core.types import TParameterization
2122
from ax.utils.common.testutils import TestCase
2223
from ax.utils.testing.core_stubs import get_branin_experiment
2324

@@ -66,10 +67,10 @@ def test_TransformObservations(self) -> None:
6667
means = np.array([3.0, 4.0])
6768
metric_signatures = ["a", "b"]
6869
covariance = np.array([[1.0, 2.0], [3.0, 4.0]])
69-
parameters = {"x": 1.0, "y": "cat"}
70+
parameters: TParameterization = {"x": 1.0, "y": "cat"}
7071
arm_name = "armmy"
7172
observation = Observation(
72-
features=ObservationFeatures(parameters=parameters), # pyre-ignore
73+
features=ObservationFeatures(parameters=parameters),
7374
data=ObservationData(
7475
metric_signatures=metric_signatures, means=means, covariance=covariance
7576
),

ax/adapter/transforms/tests/test_cast_transform.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
)
3636
from pandas import DataFrame
3737
from pandas.testing import assert_frame_equal
38+
from pyre_extensions import none_throws
3839

3940

4041
class CastTransformTest(TestCase):
@@ -179,8 +180,7 @@ def test_transform_observation_features_HSS(self) -> None:
179180
self.assertIn(p_name, obsf.parameters)
180181
# Check that full parameterization is recorded in metadata
181182
self.assertEqual(
182-
# pyre-fixme[16]: Optional type has no attribute `get`.
183-
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
183+
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
184184
self.obs_feats_hss.parameters,
185185
)
186186

@@ -197,7 +197,7 @@ def test_transform_observation_features_HSS(self) -> None:
197197
self.assertIn(p_name, obsf.parameters)
198198
# Check that full parameterization is recorded in metadata
199199
self.assertEqual(
200-
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
200+
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
201201
self.obs_feats_hss.parameters,
202202
)
203203

@@ -245,8 +245,7 @@ def test_untransform_observation_features_HSS(self) -> None:
245245
},
246246
)
247247
self.assertEqual(
248-
# pyre-fixme[16]: Optional type has no attribute `get`.
249-
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
248+
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
250249
self.obs_feats_hss.parameters,
251250
)
252251

@@ -264,7 +263,7 @@ def test_untransform_observation_features_HSS(self) -> None:
264263
},
265264
)
266265
self.assertEqual(
267-
obsf.metadata.get(Keys.FULL_PARAMETERIZATION),
266+
none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION),
268267
self.obs_feats_hss_2.parameters,
269268
)
270269

ax/adapter/transforms/tests/test_choice_encode_transform.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,8 +247,9 @@ def test_hss_dependents_are_preserved(self) -> None:
247247
# x0 should be untouched because it's a fixed parameter.
248248
self.assertIsInstance(hss.parameters["x0"], FixedParameter)
249249
self.assertEqual(hss.parameters["x0"].parameter_type, ParameterType.BOOL)
250-
# pyre-ignore[16] # Pyre doesn't understand fixed parameters have `.value`
251-
self.assertEqual(hss.parameters["x0"].value, True)
250+
self.assertEqual(
251+
assert_is_instance(hss.parameters["x0"], FixedParameter).value, True
252+
)
252253
self.assertEqual(hss.parameters["x0"].dependents, {True: ["x1", "x2"]})
253254

254255
self.assertFalse(hss.parameters["x1"].is_hierarchical)

0 commit comments

Comments
 (0)