Skip to content

Commit 2691d91

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/service/ and ax/api/ (#4979)
Summary: Pull Request resolved: #4979 Remove ~43 pyre-fixme/pyre-ignore suppression comments from 12 files in ax/service/ (source) and ax/api/ (source + tests): - Use `cast()` for API TParameterization/TParamValue type mismatches - Use `assert_is_instance()` and `assert_is_instance_optional()` for config values - Add `partial[Any]` annotation for `round_floats_for_logging` - Convert bare attribute declarations to `property abstractmethod` - Use `npt.NDArray` for numpy array types - Add `Literal` types for test parameterization Reviewed By: dme65 Differential Revision: D95264965 fbshipit-source-id: 3291e2d9da3b9cf82bd87e626135b6d011690495
1 parent 3d852de commit 2691d91

File tree

11 files changed

+120
-117
lines changed

11 files changed

+120
-117
lines changed

ax/api/client.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
from collections.abc import Iterable, Sequence
1010
from logging import Logger
11-
from typing import Any, Literal, Self
11+
from typing import Any, cast, Literal, Self
1212

1313
import numpy as np
1414
import pandas as pd
@@ -39,6 +39,7 @@
3939
from ax.core.runner import Runner
4040
from ax.core.trial import Trial
4141
from ax.core.trial_status import TrialStatus # Used as a return type
42+
from ax.core.types import TParameterization as CoreTParameterization
4243
from ax.early_stopping.strategies import (
4344
BaseEarlyStoppingStrategy,
4445
PercentileEarlyStoppingStrategy,
@@ -183,8 +184,7 @@ def configure_optimization(
183184
pruning_target_arm: Arm | None = None
184185
if pruning_target_parameterization is not None:
185186
self._experiment.search_space.validate_membership(
186-
# pyre-fixme[6]: Core Ax TParameterization is dict not Mapping
187-
parameters=pruning_target_parameterization
187+
parameters=cast(CoreTParameterization, pruning_target_parameterization)
188188
)
189189
pruning_target_arm = Arm(
190190
parameters=pruning_target_parameterization, name="pruning_target"
@@ -442,9 +442,9 @@ def get_next_trials(
442442
experiment=self._experiment,
443443
n=1,
444444
fixed_features=(
445-
# pyre-fixme[6]: Type narrowing broken because core Ax
446-
# TParameterization is dict not Mapping
447-
ObservationFeatures(parameters=fixed_parameters)
445+
ObservationFeatures(
446+
parameters=cast(CoreTParameterization, fixed_parameters)
447+
)
448448
if fixed_parameters is not None
449449
else None
450450
),
@@ -483,9 +483,10 @@ def get_next_trials(
483483
experiment=self._experiment, trials=trials
484484
)
485485

486-
# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
487-
# None, but we do not allow this in the API.
488-
return {trial.index: none_throws(trial.arm).parameters for trial in trials}
486+
return {
487+
trial.index: cast(TParameterization, none_throws(trial.arm).parameters)
488+
for trial in trials
489+
}
489490

490491
def complete_trial(
491492
self,
@@ -573,9 +574,7 @@ def attach_trial(
573574
The index of the attached trial.
574575
"""
575576
_, trial_index = self._experiment.attach_trial(
576-
# pyre-fixme[6]: Type narrowing broken because core Ax TParameterization
577-
# is dict not Mapping
578-
parameterizations=[parameters],
577+
parameterizations=[cast(CoreTParameterization, parameters)],
579578
arm_names=[arm_name] if arm_name else None,
580579
)
581580

@@ -888,13 +887,14 @@ def get_best_parameterization(
888887
)
889888
)
890889

891-
# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
892-
# None but we do not allow this in the API.
893-
return BestPointMixin._to_best_point_tuple(
894-
experiment=self._experiment,
895-
trial_index=trial_index,
896-
parameterization=parameterization,
897-
model_prediction=model_prediction,
890+
return cast(
891+
tuple[TParameterization, TOutcome, int, str],
892+
BestPointMixin._to_best_point_tuple(
893+
experiment=self._experiment,
894+
trial_index=trial_index,
895+
parameterization=parameterization,
896+
model_prediction=model_prediction,
897+
),
898898
)
899899

900900
def get_pareto_frontier(
@@ -945,14 +945,15 @@ def get_pareto_frontier(
945945
use_model_predictions=use_model_predictions,
946946
)
947947

948-
# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
949-
# None but we do not allow this in the API.
950948
return [
951-
BestPointMixin._to_best_point_tuple(
952-
experiment=self._experiment,
953-
trial_index=trial_index,
954-
parameterization=parameterization,
955-
model_prediction=model_prediction,
949+
cast(
950+
tuple[TParameterization, TOutcome, int, str],
951+
BestPointMixin._to_best_point_tuple(
952+
experiment=self._experiment,
953+
trial_index=trial_index,
954+
parameterization=parameterization,
955+
model_prediction=model_prediction,
956+
),
956957
)
957958
for trial_index, (parameterization, model_prediction) in frontier.items()
958959
]
@@ -978,9 +979,9 @@ def predict(
978979
try:
979980
mean, covariance = none_throws(self._generation_strategy.adapter).predict(
980981
observation_features=[
981-
# pyre-fixme[6]: Core Ax allows users to specify TParameterization
982-
# values as None but we do not allow this in the API.
983-
ObservationFeatures(parameters=parameters)
982+
ObservationFeatures(
983+
parameters=cast(CoreTParameterization, parameters)
984+
)
984985
for parameters in points
985986
]
986987
)

ax/api/protocols/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from abc import ABC, abstractmethod
1111
from collections import defaultdict
1212
from collections.abc import Iterable, Mapping
13-
from typing import Any
13+
from typing import Any, cast
1414

1515
import pandas as pd
1616
from ax.api.types import TParameterization
@@ -115,10 +115,10 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
115115
"""
116116
metadata = self.run_trial(
117117
trial_index=trial.index,
118-
# pyre-ignore[6] Arms in core Ax may have None in their parameters
119-
parameterization=none_throws(
120-
assert_is_instance(trial, Trial).arm
121-
).parameters,
118+
parameterization=cast(
119+
TParameterization,
120+
none_throws(assert_is_instance(trial, Trial).arm).parameters,
121+
),
122122
)
123123

124124
# Runtime validate metadata is JSON serializable to avoid issues when

ax/api/tests/test_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,10 +1794,11 @@ def test_get_next_trials_with_derived_parameters(self) -> None:
17941794
self.assertIn("x2", trial_params)
17951795
self.assertIn("x3", trial_params)
17961796
# Verify derived parameter is correctly computed
1797-
# pyre-fixme[58]: Arithmetic operations on TParameterValue
1798-
expected_x3 = 1.0 - trial_params["x1"] - trial_params["x2"]
1799-
# pyre-fixme[6]: Type mismatch on assertAlmostEqual
1800-
self.assertAlmostEqual(trial_params["x3"], expected_x3, places=6)
1797+
x1_val = assert_is_instance(trial_params["x1"], float)
1798+
x2_val = assert_is_instance(trial_params["x2"], float)
1799+
x3_val = assert_is_instance(trial_params["x3"], float)
1800+
expected_x3 = 1.0 - x1_val - x2_val
1801+
self.assertAlmostEqual(x3_val, expected_x3, places=6)
18011802

18021803
def test_complete_trial_with_derived_parameters(self) -> None:
18031804
# Setup: Configure experiment with derived parameter and generate trial

ax/api/utils/instantiation/from_config.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# pyre-strict
77

88

9+
from typing import cast
10+
911
import numpy as np
1012
from ax.api.configs import (
1113
ChoiceParameterConfig,
@@ -20,6 +22,7 @@
2022
ParameterType as CoreParameterType,
2123
RangeParameter,
2224
)
25+
from ax.core.types import TParamValue
2326
from ax.exceptions.core import UserInputError
2427

2528

@@ -79,21 +82,21 @@ def parameter_from_config(
7982
name=config.name,
8083
parameter_type=_parameter_type_converter(config.parameter_type),
8184
value=config.values[0],
82-
# pyre-fixme[6] Variance issue caused by FixedParameter.dependents
83-
# using List instead of immutable container type.
84-
dependents=config.dependent_parameters,
85+
dependents=cast(
86+
dict[TParamValue, list[str]] | None,
87+
config.dependent_parameters,
88+
),
8589
)
8690

8791
return ChoiceParameter(
8892
name=config.name,
8993
parameter_type=_parameter_type_converter(config.parameter_type),
90-
# pyre-fixme[6] Variance issue caused by ChoiceParameter.value using List
91-
# instead of immutable container type.
92-
values=config.values,
94+
values=cast(list[TParamValue], config.values),
9395
is_ordered=config.is_ordered,
94-
# pyre-fixme[6] Variance issue caused by ChoiceParameter.dependents using
95-
# List instead of immutable container type.
96-
dependents=config.dependent_parameters,
96+
dependents=cast(
97+
dict[TParamValue, list[str]] | None,
98+
config.dependent_parameters,
99+
),
97100
sort_values=config.parameter_type != "str", # Matches default behavior.
98101
)
99102

ax/api/utils/tests/test_generation_strategy_dispatch.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
from itertools import product
11-
from typing import Any
11+
from typing import Any, Literal
1212

1313
import torch
1414
from ax.adapter.registry import Generators
@@ -234,13 +234,9 @@ def test_choose_gs_single_sobol_initialization(self) -> None:
234234
self.assertEqual(mbm_node.name, "MBM")
235235

236236
def test_gs_simplify_parameter_changes(self) -> None:
237-
for simplify, method in product((True, False), ("fast", "quality")):
237+
methods: list[Literal["fast", "quality"]] = ["fast", "quality"]
238+
for simplify, method in product((True, False), methods):
238239
struct = GenerationStrategyDispatchStruct(
239-
# pyre-fixme [6]: In call
240-
# `GenerationStrategyDispatchStruct.__init__`, for argument
241-
# `method`, expected `Union[typing_extensions.Literal['fast'],
242-
# typing_extensions.Literal['quality'],
243-
# typing_extensions.Literal['random_search']]` but got `str`
244240
method=method,
245241
simplify_parameter_changes=simplify,
246242
)
@@ -323,7 +319,13 @@ def test_choose_gs_with_custom_botorch_acqf_class(self) -> None:
323319
"""Test that custom botorch_acqf_class is properly passed to generator kwargs
324320
and appended to the node name. Tests both fast and custom methods.
325321
"""
326-
for method, model_config, expected_name in [
322+
test_cases: list[
323+
tuple[
324+
Literal["quality", "fast", "random_search", "custom"],
325+
ModelConfig | None,
326+
str,
327+
]
328+
] = [
327329
("fast", None, "Sobol+MBM:fast+qLogNoisyExpectedImprovement"),
328330
(
329331
"custom",
@@ -333,10 +335,11 @@ def test_choose_gs_with_custom_botorch_acqf_class(self) -> None:
333335
),
334336
"Sobol+MBM:MAPSAAS+qLogNoisyExpectedImprovement",
335337
),
336-
]:
338+
]
339+
for method, model_config, expected_name in test_cases:
337340
with self.subTest(method=method):
338341
struct = GenerationStrategyDispatchStruct(
339-
method=method, # pyre-ignore [6]
342+
method=method,
340343
initialization_budget=3,
341344
initialize_with_center=False,
342345
)

ax/service/ax_client.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@
8787

8888
ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6
8989

90-
# pyre-fixme[5]: Global expression must be annotated.
91-
round_floats_for_logging = partial(
90+
round_floats_for_logging: partial[Any] = partial(
9291
_round_floats_for_logging,
9392
decimal_places=ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES,
9493
)
@@ -1303,8 +1302,7 @@ def save_to_json_file(self, filepath: str = "ax_client_snapshot.json") -> None:
13031302
def load_from_json_file(
13041303
cls: type[AxClientSubclass],
13051304
filepath: str = "ax_client_snapshot.json",
1306-
# pyre-fixme[2]: Parameter must be annotated.
1307-
**kwargs,
1305+
**kwargs: Any,
13081306
) -> AxClientSubclass:
13091307
"""Restore an `AxClient` and its state from a JSON-serialized snapshot,
13101308
residing in a .json file by the given path.
@@ -1315,13 +1313,10 @@ def load_from_json_file(
13151313

13161314
def to_json_snapshot(
13171315
self,
1318-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
1319-
# `typing.Type` to avoid runtime subscripting errors.
1320-
encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] | None = None,
1321-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
1322-
# `typing.Type` to avoid runtime subscripting errors.
1323-
class_encoder_registry: None
1324-
| (dict[type, Callable[[Any], dict[str, Any]]]) = None,
1316+
encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]]
1317+
| None = None,
1318+
class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]]
1319+
| None = None,
13251320
) -> dict[str, Any]:
13261321
"""Serialize this `AxClient` to JSON to be able to interrupt and restart
13271322
optimization and save it to file by the provided path.
@@ -1357,8 +1352,7 @@ def from_json_snapshot(
13571352
decoder_registry: TDecoderRegistry | None = None,
13581353
class_decoder_registry: None
13591354
| (dict[str, Callable[[dict[str, Any]], Any]]) = None,
1360-
# pyre-fixme[2]: Parameter must be annotated.
1361-
**kwargs,
1355+
**kwargs: Any,
13621356
) -> AxClientSubclass:
13631357
"""Recreate an `AxClient` from a JSON snapshot."""
13641358
if decoder_registry is None:

ax/service/interactive_loop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def interactive_optimize(
2828
candidate_queue_maxsize: int,
2929
candidate_generator_function: Callable[..., None],
3030
data_attacher_function: Callable[..., None],
31-
# pyre-ignore[2]: Missing parameter annotation
3231
elicitation_function: Callable[..., Any],
3332
candidate_generator_kwargs: dict[str, Any] | None = None,
3433
data_attacher_kwargs: dict[str, Any] | None = None,

ax/service/managed_loop.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import inspect
1212
import logging
1313
import warnings
14-
from collections.abc import Iterable
14+
from collections.abc import Callable, Iterable
15+
from typing import cast
1516

1617
# Manual import to avoid strange error, see Diff for details.
1718
import ax.generation_strategy.generation_node_input_constructors # noqa
@@ -73,12 +74,13 @@ def __init__(
7374
)
7475
self.experiment = experiment
7576
if generation_strategy is None:
76-
# pyre-fixme[4]: Attribute must be annotated.
77-
self.generation_strategy = choose_generation_strategy_legacy(
78-
search_space=experiment.search_space,
79-
use_batch_trials=self.arms_per_trial > 1,
80-
random_seed=self.random_seed,
81-
experiment=experiment,
77+
self.generation_strategy: GenerationStrategy = (
78+
choose_generation_strategy_legacy(
79+
search_space=experiment.search_space,
80+
use_batch_trials=self.arms_per_trial > 1,
81+
random_seed=self.random_seed,
82+
experiment=experiment,
83+
)
8284
)
8385
else:
8486
self.generation_strategy = generation_strategy
@@ -147,11 +149,15 @@ def _call_evaluation_function(
147149
signature = inspect.signature(self.evaluation_function)
148150
num_evaluation_function_params = len(signature.parameters.items())
149151
if num_evaluation_function_params == 1:
150-
# pyre-ignore [20]: Can't run instance checks on subscripted generics.
151-
evaluation = self.evaluation_function(parameterization)
152+
evaluation = cast(
153+
Callable[[TParameterization], TEvaluationOutcome],
154+
self.evaluation_function,
155+
)(parameterization)
152156
elif num_evaluation_function_params == 2:
153-
# pyre-ignore [19]: Can't run instance checks on subscripted generics.
154-
evaluation = self.evaluation_function(parameterization, weight)
157+
evaluation = cast(
158+
Callable[[TParameterization, float | None], TEvaluationOutcome],
159+
self.evaluation_function,
160+
)(parameterization, weight)
155161
else:
156162
raise UserInputError(
157163
"Evaluation function must take either one parameter "

0 commit comments

Comments
 (0)