Skip to content

Commit 9a79809

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/service, ax/storage, ax/utils test files (#4990)
Summary: Pull Request resolved: #4990 Remove pyre-fixme and pyre-ignore type suppression comments from test files in ax/service/tests, ax/storage/*/tests, and ax/utils/*/tests. Uses proper type narrowing via none_throws, assert_is_instance, cast, and explicit type annotations instead of suppression comments. Reviewed By: dme65 Differential Revision: D95273568 fbshipit-source-id: c8a716f5558dcd443692b9b53050d02fefa0c8e1
1 parent 38bf30f commit 9a79809

File tree

15 files changed

+308
-342
lines changed

15 files changed

+308
-342
lines changed

ax/service/tests/test_ax_client.py

Lines changed: 147 additions & 167 deletions
Large diffs are not rendered by default.

ax/service/tests/test_global_stopping.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def get_ax_client_for_branin(
4848
def evaluate(self, parameters: TParameterization) -> dict[str, tuple[float, float]]:
4949
"""Evaluates the parameters for branin experiment."""
5050
x = np.array([parameters.get(f"x{i + 1}") for i in range(2)])
51-
# pyre-fixme[7]: Expected `Dict[str, Tuple[float, float]]` but got
52-
# `Dict[str, Tuple[Union[float, ndarray], float]]`.
53-
return {"branin": (branin(x), 0.0)}
51+
return {"branin": (float(branin(x)), 0.0)}
5452

5553
def test_global_stopping_integration(self) -> None:
5654
"""
@@ -69,7 +67,6 @@ def test_global_stopping_integration(self) -> None:
6967
parameters, trial_index = ax_client.get_next_trial()
7068
ax_client.complete_trial(
7169
trial_index=trial_index,
72-
# pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[...
7370
raw_data=self.evaluate(parameters),
7471
)
7572

@@ -109,15 +106,13 @@ def test_min_trials(self) -> None:
109106
parameters, trial_index = ax_client.get_next_trial()
110107
ax_client.complete_trial(
111108
trial_index=trial_index,
112-
# pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[...
113109
raw_data=self.evaluate(parameters),
114110
)
115111

116112
# Since min_trials=3, GSS should not stop creating the 3rd iteration.
117113
parameters, trial_index = ax_client.get_next_trial()
118114
ax_client.complete_trial(
119115
trial_index=trial_index,
120-
# pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[Unio...
121116
raw_data=self.evaluate(parameters),
122117
)
123118
self.assertIsNotNone(parameters)

ax/service/tests/test_instantiation_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from collections.abc import Sequence
910
from typing import Any
1011

1112
from ax.core.metric import Metric
@@ -16,6 +17,7 @@
1617
ParameterType,
1718
RangeParameter,
1819
)
20+
from ax.core.types import TParamValue
1921
from ax.runners.synthetic import SyntheticRunner
2022
from ax.service.utils.instantiation import InstantiationBase
2123
from ax.utils.common.testutils import TestCase
@@ -332,7 +334,9 @@ def test_choice_with_is_sorted(self) -> None:
332334
_ = InstantiationBase.parameter_from_json(representation)
333335

334336
def test_hss(self) -> None:
335-
parameter_dicts = [
337+
parameter_dicts: list[
338+
dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]]
339+
] = [
336340
{
337341
"name": "root",
338342
"type": "fixed",
@@ -361,7 +365,6 @@ def test_hss(self) -> None:
361365
{"name": "another_int", "type": "fixed", "value": "2"},
362366
]
363367
search_space = InstantiationBase.make_search_space(
364-
# pyre-fixme[6]: For 1st param expected `List[Dict[str, Union[None, Dict[...
365368
parameters=parameter_dicts,
366369
parameter_constraints=[],
367370
)

ax/service/tests/test_interactive_loop.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
import functools
1111
import time
12+
from collections.abc import Callable, Sequence
1213
from logging import WARN
1314
from queue import Queue
1415
from threading import Event, Lock
16+
from typing import cast
1517

1618
import numpy as np
1719
from ax.adapter.registry import Generators
18-
from ax.core.types import TEvaluationOutcome, TParameterization
20+
from ax.core.types import TEvaluationOutcome, TParameterization, TParamValue
1921
from ax.generation_strategy.generation_strategy import (
2022
GenerationStep,
2123
GenerationStrategy,
@@ -41,17 +43,20 @@ def setUp(self) -> None:
4143
]
4244
)
4345
self.ax_client = AxClient(generation_strategy=generation_strategy)
44-
self.ax_client.create_experiment(
45-
name="hartmann_test_experiment",
46-
# pyre-fixme[6]
47-
parameters=[
46+
parameters = cast(
47+
list[dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]]],
48+
[
4849
{
4950
"name": f"x{i}",
5051
"type": "range",
5152
"bounds": [0.0, 1.0],
5253
}
5354
for i in range(1, 7)
5455
],
56+
)
57+
self.ax_client.create_experiment(
58+
name="hartmann_test_experiment",
59+
parameters=parameters,
5560
objectives={"hartmann6": ObjectiveProperties(minimize=True)},
5661
tracking_metric_names=["l2norm"],
5762
)
@@ -76,8 +81,10 @@ def test_interactive_loop(self) -> None:
7681
ax_client=self.ax_client,
7782
num_trials=15,
7883
candidate_queue_maxsize=3,
79-
# pyre-fixme[6]
80-
elicitation_function=self._elicit,
84+
elicitation_function=cast(
85+
Callable[[tuple[TParameterization, int]], TEvaluationOutcome],
86+
self._elicit,
87+
),
8188
)
8289

8390
self.assertTrue(optimization_completed)
@@ -94,8 +101,10 @@ def _aborted_elicit(
94101
ax_client=self.ax_client,
95102
num_trials=15,
96103
candidate_queue_maxsize=3,
97-
# pyre-fixme[6]
98-
elicitation_function=_aborted_elicit,
104+
elicitation_function=cast(
105+
Callable[[tuple[TParameterization, int]], TEvaluationOutcome],
106+
_aborted_elicit,
107+
),
99108
)
100109
self.assertFalse(optimization_completed)
101110

@@ -144,8 +153,10 @@ def _sleep_elicit(
144153
ax_client=self.ax_client,
145154
num_trials=3,
146155
candidate_queue_maxsize=3,
147-
# pyre-fixme[6]
148-
elicitation_function=_sleep_elicit,
156+
elicitation_function=cast(
157+
Callable[[tuple[TParameterization, int]], TEvaluationOutcome],
158+
_sleep_elicit,
159+
),
149160
)
150161

151162
# Assert sleep and retry warning is somewhere in the logs

ax/service/tests/test_managed_loop.py

Lines changed: 37 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,63 +6,63 @@
66

77
# pyre-strict
88

9+
from typing import Any
910
from unittest.mock import Mock, patch
1011

1112
import numpy as np
12-
import numpy.typing as npt
1313
from ax.adapter.registry import Generators
14+
from ax.core.types import TParameterization
1415
from ax.exceptions.core import UserInputError
1516
from ax.generation_strategy.generation_strategy import (
1617
GenerationStep,
1718
GenerationStrategy,
1819
)
20+
from ax.generators.random.sobol import SobolGenerator
1921
from ax.metrics.branin import branin
2022
from ax.service.managed_loop import OptimizationLoop, optimize
2123
from ax.utils.common.testutils import TestCase
2224
from ax.utils.testing.mock import mock_botorch_optimize
25+
from pyre_extensions import assert_is_instance, none_throws
2326

2427

2528
def _branin_evaluation_function(
26-
# pyre-fixme[2]: Parameter must be annotated.
27-
parameterization,
28-
weight=None, # pyre-fixme[2]: Parameter must be annotated.
29-
) -> dict[str, tuple[float | npt.NDArray, float]]:
29+
parameterization: TParameterization,
30+
weight: float | None = None,
31+
) -> dict[str, tuple[float, float]]:
3032
if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]):
3133
raise ValueError("Parametrization does not contain x1 or x2")
32-
x1, x2 = parameterization["x1"], parameterization["x2"]
34+
x1, x2 = float(parameterization["x1"]), float(parameterization["x2"])
3335
return {
34-
"branin": (branin(x1, x2), 0.0),
35-
"constrained_metric": (-branin(x1, x2), 0.0),
36+
"branin": (float(branin(x1, x2)), 0.0),
37+
"constrained_metric": (float(-branin(x1, x2)), 0.0),
3638
}
3739

3840

3941
def _branin_evaluation_function_v2(
40-
# pyre-fixme[2]: Parameter must be annotated.
41-
parameterization,
42-
weight=None, # pyre-fixme[2]: Parameter must be annotated.
43-
) -> tuple[float | npt.NDArray, float]:
42+
parameterization: TParameterization,
43+
weight: float | None = None,
44+
) -> tuple[float, float]:
4445
if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]):
4546
raise ValueError("Parametrization does not contain x1 or x2")
46-
x1, x2 = parameterization["x1"], parameterization["x2"]
47-
return (branin(x1, x2), 0.0)
47+
x1, x2 = float(parameterization["x1"]), float(parameterization["x2"])
48+
return (float(branin(x1, x2)), 0.0)
4849

4950

5051
def _branin_evaluation_function_with_unknown_sem(
51-
# pyre-fixme[2]: Parameter must be annotated.
52-
parameterization,
53-
weight=None, # pyre-fixme[2]: Parameter must be annotated.
54-
) -> tuple[float | npt.NDArray, None]:
52+
parameterization: TParameterization,
53+
weight: float | None = None,
54+
) -> tuple[float, None]:
5555
if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]):
5656
raise ValueError("Parametrization does not contain x1 or x2")
57-
x1, x2 = parameterization["x1"], parameterization["x2"]
58-
return (branin(x1, x2), None)
57+
x1, x2 = float(parameterization["x1"]), float(parameterization["x2"])
58+
return (float(branin(x1, x2)), None)
5959

6060

6161
class TestManagedLoop(TestCase):
6262
"""Check functionality of optimization loop."""
6363

6464
def test_with_evaluation_function_propagates_parameter_constraints(self) -> None:
65-
kwargs = {
65+
kwargs: dict[str, Any] = {
6666
"parameters": [
6767
{
6868
"name": "x1",
@@ -151,9 +151,7 @@ def test_branin_with_active_parameter_constraints(self) -> None:
151151
bp, _ = loop.full_run().get_best_point()
152152
self.assertIn("x1", bp)
153153
self.assertIn("x2", bp)
154-
# pyre-fixme[58]: `+` is not supported for operand types `Union[None, bool,
155-
# float, int, str]` and `Union[None, bool, float, int, str]`.
156-
self.assertLessEqual(bp["x1"] + bp["x2"], 1.0 + 1e-8)
154+
self.assertLessEqual(float(bp["x1"]) + float(bp["x2"]), 1.0 + 1e-8)
157155
with self.assertRaisesRegex(ValueError, "Optimization is complete"):
158156
loop.run_trial()
159157

@@ -241,11 +239,8 @@ def test_branin_batch(self) -> None:
241239
self.assertIn("x2", bp)
242240
assert vals is not None
243241
self.assertIn("branin", vals[0])
244-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
245-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
246-
self.assertIn("branin", vals[1])
247-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
248-
self.assertIn("branin", vals[1]["branin"])
242+
self.assertIn("branin", none_throws(vals[1]))
243+
self.assertIn("branin", none_throws(vals[1])["branin"])
249244
# Check that all total_trials * arms_per_trial * 2 metrics evaluations
250245
# are present in the dataframe.
251246
self.assertEqual(len(loop.experiment.fetch_data().df.index), 12)
@@ -270,11 +265,8 @@ def test_optimize(self) -> None:
270265
self.assertIn("x2", best)
271266
assert vals is not None
272267
self.assertIn("objective", vals[0])
273-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
274-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
275-
self.assertIn("objective", vals[1])
276-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
277-
self.assertIn("objective", vals[1]["objective"])
268+
self.assertIn("objective", none_throws(vals[1]))
269+
self.assertIn("objective", none_throws(vals[1])["objective"])
278270

279271
@patch(
280272
"ax.service.managed_loop."
@@ -301,11 +293,8 @@ def test_optimize_with_predictions(self, _) -> None:
301293
self.assertIn("x2", best)
302294
assert vals is not None
303295
self.assertIn("a", vals[0])
304-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
305-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
306-
self.assertIn("a", vals[1])
307-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
308-
self.assertIn("a", vals[1]["a"])
296+
self.assertIn("a", none_throws(vals[1]))
297+
self.assertIn("a", none_throws(vals[1])["a"])
309298

310299
@mock_botorch_optimize
311300
def test_optimize_unknown_sem(self) -> None:
@@ -327,11 +316,8 @@ def test_optimize_unknown_sem(self) -> None:
327316
self.assertIn("x2", best)
328317
self.assertIsNotNone(vals)
329318
self.assertIn("objective", vals[0])
330-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
331-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
332-
self.assertIn("objective", vals[1])
333-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
334-
self.assertIn("objective", vals[1]["objective"])
319+
self.assertIn("objective", none_throws(vals[1]))
320+
self.assertIn("objective", none_throws(vals[1])["objective"])
335321

336322
def test_optimize_propagates_random_seed(self) -> None:
337323
"""Tests optimization as a single call."""
@@ -347,8 +333,8 @@ def test_optimize_propagates_random_seed(self) -> None:
347333
total_trials=5,
348334
random_seed=12345,
349335
)
350-
# pyre-fixme[16]: Optional type has no attribute `model`.
351-
self.assertEqual(12345, model.generator.seed)
336+
generator = assert_is_instance(none_throws(model).generator, SobolGenerator)
337+
self.assertEqual(12345, generator.seed)
352338

353339
def test_optimize_search_space_exhausted(self) -> None:
354340
"""Tests optimization as a single call."""
@@ -370,11 +356,8 @@ def test_optimize_search_space_exhausted(self) -> None:
370356
self.assertIn("x2", best)
371357
self.assertIsNotNone(vals)
372358
self.assertIn("objective", vals[0])
373-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
374-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
375-
self.assertIn("objective", vals[1])
376-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
377-
self.assertIn("objective", vals[1]["objective"])
359+
self.assertIn("objective", none_throws(vals[1]))
360+
self.assertIn("objective", none_throws(vals[1])["objective"])
378361

379362
def test_custom_gs(self) -> None:
380363
"""Managed loop with custom generation strategy"""
@@ -432,18 +415,14 @@ def test_optimize_graceful_exit_on_exception(self) -> None:
432415
self.assertIn("x2", best)
433416
self.assertIsNotNone(vals)
434417
self.assertIn("objective", vals[0])
435-
# pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
436-
# Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
437-
self.assertIn("objective", vals[1])
438-
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
439-
self.assertIn("objective", vals[1]["objective"])
418+
self.assertIn("objective", none_throws(vals[1]))
419+
self.assertIn("objective", none_throws(vals[1])["objective"])
440420

441421
@patch(
442422
"ax.core.experiment.Experiment.new_trial",
443423
side_effect=RuntimeError("cholesky_cpu error - bad matrix"),
444424
)
445-
# pyre-fixme[3]: Return type must be annotated.
446-
def test_annotate_exception(self, _):
425+
def test_annotate_exception(self, _: Mock) -> None:
447426
strategy0 = GenerationStrategy(
448427
name="Sobol",
449428
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)],

ax/service/tests/test_report_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import itertools
1010
import logging
1111
from collections import namedtuple
12+
from collections.abc import Callable
1213
from logging import DEBUG, INFO, WARN
14+
from typing import Any
1315
from unittest import mock
1416
from unittest.mock import patch
1517

@@ -265,8 +267,7 @@ def test_exp_to_df(self) -> None:
265267
with patch.object(Experiment, "lookup_data", lambda self: mock_results):
266268
df = exp_to_df(exp=exp)
267269
# all but two rows should have a metric value of NaN
268-
# pyre-fixme[16]: `bool` has no attribute `sum`.
269-
self.assertEqual(pd.isna(df[OBJECTIVE_NAME]).sum(), len(df.index) - 2)
270+
self.assertEqual(df[OBJECTIVE_NAME].isna().sum(), len(df.index) - 2)
270271

271272
# an experiment with more results than arms raises an error
272273
with (
@@ -369,16 +370,16 @@ def test_get_standard_plots(self) -> None:
369370
self.assertTrue(all(isinstance(plot, go.Figure) for plot in plots))
370371

371372
# Raise an exception in one plot and make sure we generate the others
372-
for plot_function, num_expected_plots in [
373-
[_get_curve_plot_dropdown, 8], # Not used
374-
[_get_objective_trace_plot, 6],
375-
[_objective_vs_true_objective_scatter, 7],
376-
[_get_objective_v_param_plots, 6],
377-
[_get_cross_validation_plots, 7],
378-
[plot_feature_importance_by_feature_plotly, 6],
379-
]:
373+
plot_test_cases: list[tuple[Callable[..., Any], int]] = [
374+
(_get_curve_plot_dropdown, 8), # Not used
375+
(_get_objective_trace_plot, 6),
376+
(_objective_vs_true_objective_scatter, 7),
377+
(_get_objective_v_param_plots, 6),
378+
(_get_cross_validation_plots, 7),
379+
(plot_feature_importance_by_feature_plotly, 6),
380+
]
381+
for plot_function, num_expected_plots in plot_test_cases:
380382
with mock.patch(
381-
# pyre-ignore
382383
f"ax.service.utils.report_utils.{plot_function.__name__}",
383384
side_effect=Exception(),
384385
):

0 commit comments

Comments
 (0)