Skip to content

Commit 4059782

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/core/ source files (#4976)
Summary: Pull Request resolved: #4976 Remove ~29 pyre-fixme/pyre-ignore suppression comments from 11 source files in ax/core/ by applying proper type fixes: - Use `none_throws()` for Optional unwrapping - Use `cast()` for type narrowing - Add proper type annotations (e.g., `partial[Any]`, `type[Any]`) - Refactor pandas `itertuples()` access to use `Any`-typed row variable - Make `Parameter.clone()` abstract with `abstractmethod` - Use `float()` casts for numeric comparisons - Remove explicit `: int` from enum members (pyre-fixme[35]) 24 genuinely unfixable suppressions remain (documented): - `copy_doc` with `property` interaction - Property setter decorator type inference - `np.floating`/`np.integer` generic params (runtime isinstance limitation) - Intentional inconsistent overrides in OptimizationConfig subclasses - Abstract class instantiation via `type(self)` Reviewed By: dme65 Differential Revision: D95264859 fbshipit-source-id: 77326c274feb662a6e2c57be9347ec09ae005c0e
1 parent 2691d91 commit 4059782

File tree

11 files changed

+69
-99
lines changed

11 files changed

+69
-99
lines changed

ax/core/base_trial.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,12 @@
3838
MAX_ABANDONED_REASON_LENGTH = 1000
3939

4040

41-
def immutable_once_run(func: Callable) -> Callable:
41+
def immutable_once_run(func: Callable[..., Any]) -> Callable[..., Any]:
4242
"""Decorator for methods that should throw Error when
4343
trial is running or has ever run and immutable.
4444
"""
4545

46-
# no type annotation for now; breaks sphinx-autodoc-typehints
47-
# pyre-fixme[3]: Return type must be annotated.
48-
# pyre-fixme[2]: Parameter must be annotated.
49-
def _immutable_once_run(self, *args, **kwargs):
46+
def _immutable_once_run(self: Any, *args: Any, **kwargs: Any) -> Any:
5047
if self._status != TrialStatus.CANDIDATE:
5148
raise TrialMutationError(
5249
"Cannot modify a trial that is running or has ever run. "

ax/core/batch_trial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def in_design_arms(self) -> list[Arm]:
396396
if self.experiment.search_space.check_membership(arm.parameters)
397397
]
398398

399-
# pyre-ignore[6]: T77111662.
399+
# pyre-ignore[6]: pyre does not understand @copy_doc with @property.
400400
@copy_doc(BaseTrial.generator_runs)
401401
@property
402402
def generator_runs(self) -> list[GeneratorRun]:

ax/core/data.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from functools import cached_property
1717
from io import StringIO
1818
from logging import Logger
19-
from typing import Any
19+
from typing import Any, cast
2020

2121
import numpy as np
2222
import numpy.typing as npt
@@ -177,29 +177,24 @@ def __init__(
177177
f"Dataframe must contain required columns {list(missing_columns)}."
178178
)
179179

180+
# Using itertuples() instead of iterrows() for speed.
181+
# cast() suppresses pyre errors on namedtuple attribute access.
180182
self._data_rows = [
181183
DataRow(
182-
# pyre-ignore[16] Intentional unsafe namedtuple access
183184
# int() cast needed because pd.read_json with dtype=False
184185
# can return string trial indices from storage
185186
trial_index=int(row.trial_index),
186-
# pyre-ignore[16] Intentional unsafe namedtuple access
187187
arm_name=row.arm_name,
188-
# pyre-ignore[16] Intentional unsafe namedtuple access
189188
metric_name=row.metric_name,
190-
# pyre-ignore[16] Intentional unsafe namedtuple access
191189
metric_signature=row.metric_signature,
192-
# pyre-ignore[16] Intentional unsafe namedtuple access
193190
mean=row.mean,
194-
# pyre-ignore[16] Intentional unsafe namedtuple access
195191
se=row.sem,
196192
step=getattr(row, "step", None),
197193
start_time=getattr(row, "start_time", None),
198194
end_time=getattr(row, "end_time", None),
199195
n=getattr(row, "n", None),
200196
)
201-
# Using itertuples() instead of iterrows() for speed
202-
for row in df.itertuples(index=False)
197+
for row in cast(Any, df.itertuples(index=False))
203198
]
204199
else:
205200
self._data_rows = []

ax/core/evaluations_to_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from collections.abc import Mapping
1010
from enum import Enum
11+
from typing import cast
1112

1213
from ax.core.data import Data, DataRow
1314
from ax.core.types import FloatLike, SingleMetricData, TEvaluationOutcome
@@ -117,9 +118,11 @@ def raw_evaluations_to_data(
117118
"multiple metrics."
118119
)
119120
metric_name = next(iter(metric_name_to_signature.keys()))
120-
# pyre-fixme[6]: Incmopatible parameter type (Pyre doesn't know that
121-
# this is in fact a SingleMetricData)
122-
mean, sem = _validate_and_extract_single_metric_data(dat=evaluation)
121+
# After eliminating dict and list cases above, evaluation is
122+
# SingleMetricData, but pyre can't narrow union type aliases.
123+
mean, sem = _validate_and_extract_single_metric_data(
124+
dat=cast(SingleMetricData, evaluation)
125+
)
123126
metric_names_seen.add(metric_name)
124127
data_rows.append(
125128
DataRow(

ax/core/experiment.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@
7272

7373
ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6
7474

75-
# pyre-fixme[5]: Global expression must be annotated.
76-
round_floats_for_logging = partial(
75+
round_floats_for_logging: partial[Any] = partial(
7776
_round_floats_for_logging,
7877
decimal_places=ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES,
7978
)
@@ -135,9 +134,7 @@ def __init__(
135134
DeprecationWarning,
136135
stacklevel=2,
137136
)
138-
# appease pyre
139-
# pyre-fixme[13]: Attribute `_search_space` is never initialized.
140-
self._search_space: SearchSpace
137+
self._search_space: SearchSpace = search_space
141138
self._status_quo: Arm | None = None
142139

143140
self._name = name
@@ -194,7 +191,6 @@ def __init__(
194191
self.add_tracking_metrics(tracking_metrics or [])
195192

196193
# call setters defined below
197-
self.search_space: SearchSpace = search_space
198194
self.status_quo = status_quo
199195
if optimization_config is not None:
200196
self.optimization_config = optimization_config

ax/core/metric.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from datetime import timedelta
1616
from functools import reduce
1717
from logging import Logger
18-
from typing import Any, TYPE_CHECKING
18+
from typing import Any, cast, TYPE_CHECKING
1919

2020
from ax.core.data import Data
2121
from ax.utils.common.base import SortableBase
@@ -461,18 +461,16 @@ def maybe_raise_deprecation_warning_on_class_methods(self) -> None:
461461
# This is a temporary hack to allow us to deprecate old metric class method
462462
# implementations. There does not seem to be another way of checking whether
463463
# base class' classmethods are overridden in subclasses.
464-
is_fetch_trial_data_multi_overriden = (
465-
getattr(self.__class__.fetch_trial_data_multi, "__code__", "DEFAULT")
466-
!= Metric.fetch_trial_data_multi.__code__ # pyre-ignore[16]
467-
)
468-
is_fetch_experiment_data_multi_overriden = (
469-
getattr(
470-
self.__class__.fetch_experiment_data_multi,
471-
"__code__",
472-
"DEFAULT",
473-
)
474-
!= Metric.fetch_experiment_data_multi.__code__ # pyre-ignore[16]
475-
)
464+
# Asymmetric defaults: if __code__ is missing on the subclass, "DEFAULT"
465+
# != None is True, treating the method as overridden (safe default).
466+
is_fetch_trial_data_multi_overriden = getattr(
467+
self.__class__.fetch_trial_data_multi, "__code__", "DEFAULT"
468+
) != getattr(Metric.fetch_trial_data_multi, "__code__", None)
469+
is_fetch_experiment_data_multi_overriden = getattr(
470+
self.__class__.fetch_experiment_data_multi,
471+
"__code__",
472+
"DEFAULT",
473+
) != getattr(Metric.fetch_experiment_data_multi, "__code__", None)
476474
# Raise deprecation warning if this method from the base class is used (meaning
477475
# that it is not overridden and the classmethod is overridden instead), unless
478476
# the only overridden method is `fetch_trial_data` (in which case the setup is
@@ -623,16 +621,18 @@ def _wrap_trial_data_multi(cls, data: Data) -> dict[str, MetricFetchResult]:
623621
def _wrap_experiment_data_multi(
624622
cls, data: Data
625623
) -> dict[int, dict[str, MetricFetchResult]]:
626-
# pyre-fixme[7]
627-
return {
628-
trial_index: {
629-
metric_signature: Ok(
630-
value=data.filter(
631-
trial_indices=[trial_index],
632-
metric_signatures=[metric_signature],
624+
return cast(
625+
dict[int, dict[str, MetricFetchResult]],
626+
{
627+
trial_index: {
628+
metric_signature: Ok(
629+
value=data.filter(
630+
trial_indices=[trial_index],
631+
metric_signatures=[metric_signature],
632+
)
633633
)
634-
)
635-
for metric_signature in data.full_df["metric_signature"]
636-
}
637-
for trial_index in data.full_df["trial_index"]
638-
}
634+
for metric_signature in data.full_df["metric_signature"]
635+
}
636+
for trial_index in data.full_df["trial_index"]
637+
},
638+
)

ax/core/multi_type_experiment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,12 @@ def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
109109
self._trial_type_to_runner[trial_type] = runner
110110
return self
111111

112-
# pyre-fixme [56]: Pyre was not able to infer the type of the decorator
113-
# `Experiment.optimization_config.setter`.
112+
# pyre does not support inferring the type of property setter decorators
113+
# or the `.fset` attribute on properties.
114+
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator.
114115
@Experiment.optimization_config.setter
115116
def optimization_config(self, optimization_config: OptimizationConfig) -> None:
116-
# pyre-fixme [16]: `Optional` has no attribute `fset`.
117+
# pyre-fixme[16]: `Optional` has no attribute `fset`.
117118
Experiment.optimization_config.fset(self, optimization_config)
118119
for metric_name in optimization_config.metrics.keys():
119120
# Optimization config metrics are required to be the default trial type

ax/core/parameter.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,10 @@
5555

5656

5757
class ParameterType(Enum):
58-
# pyre-fixme[35]: Target cannot be annotated.
59-
BOOL: int = 0
60-
# pyre-fixme[35]: Target cannot be annotated.
61-
INT: int = 1
62-
# pyre-fixme[35]: Target cannot be annotated.
63-
FLOAT: int = 2
64-
# pyre-fixme[35]: Target cannot be annotated.
65-
STRING: int = 3
58+
BOOL = 0
59+
INT = 1
60+
FLOAT = 2
61+
STRING = 3
6662

6763
@property
6864
def is_numeric(self) -> bool:
@@ -71,10 +67,7 @@ def is_numeric(self) -> bool:
7167

7268
TParameterType = Union[type[int], type[float], type[str], type[bool]]
7369

74-
# pyre: PARAMETER_PYTHON_TYPE_MAP is declared to have type
75-
# pyre: `Dict[ParameterType, Union[Type[bool], Type[float], Type[int],
76-
# pyre: Type[str]]]` but is used as type `Dict[ParameterType,
77-
# pyre-fixme[9]: Type[Union[float, str]]]`.
70+
# pyre-fixme[9]: Pyre collapses individual type[] values into Type[Union[...]].
7871
PARAMETER_PYTHON_TYPE_MAP: dict[ParameterType, TParameterType] = {
7972
ParameterType.INT: int,
8073
ParameterType.FLOAT: float,
@@ -87,9 +80,7 @@ def is_numeric(self) -> bool:
8780
] = tuple(PARAMETER_PYTHON_TYPE_MAP.values())
8881

8982

90-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
91-
# avoid runtime subscripting errors.
92-
def _get_parameter_type(python_type: type) -> ParameterType:
83+
def _get_parameter_type(python_type: type[Any]) -> ParameterType:
9384
"""Given a Python type, retrieve corresponding Ax ``ParameterType``."""
9485
for param_type, py_type in PARAMETER_PYTHON_TYPE_MAP.items():
9586
if py_type == python_type:
@@ -236,9 +227,8 @@ def dependents(self) -> dict[TParamValue, list[str]]:
236227
"Only fixed and choice hierarchical parameters are currently supported."
237228
)
238229

239-
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
240-
def clone(self) -> Self:
241-
pass
230+
@abstractmethod
231+
def clone(self) -> Self: ...
242232

243233
def disable(self, default_value: TParamValue) -> None:
244234
"""
@@ -541,9 +531,7 @@ def set_digits(self, digits: int | None) -> RangeParameter:
541531
# Re-scale min and max to new digits definition
542532
cast_lower = self.cast(self._lower)
543533
cast_upper = self.cast(self._upper)
544-
# `<=` is not supported for operand types `Union[float, int]` and `int`.
545-
# pyre-ignore [58]
546-
if cast_lower >= cast_upper:
534+
if float(cast_lower) >= float(cast_upper):
547535
raise UserInputError(
548536
f"Lower bound {cast_lower} is >= upper bound {cast_upper}."
549537
)
@@ -1304,13 +1292,7 @@ class DerivedParameter(Parameter):
13041292
extendable to non-linear functions.
13051293
"""
13061294

1307-
# pyre-fixme [13]: Uninitialized attribute [13]: Attribute `_intercept` is
1308-
# declared in class `DerivedParameter` to have type `float` but is never
1309-
# initialized.
13101295
_intercept: float
1311-
# pyre-fixme [13]: Uninitialized attribute [13]: Attribute
1312-
# `_parameter_names_to_weights` is declared in class `DerivedParameter` to#
1313-
# have type `typing.Dict[str, float]` but is never initialized.
13141296
_parameter_names_to_weights: dict[str, float]
13151297

13161298
def __init__(
@@ -1343,6 +1325,8 @@ def __init__(
13431325
self._parameter_type = parameter_type # Set first so validation works
13441326
self._is_fidelity = is_fidelity
13451327
self._target_value = target_value
1328+
self._intercept = 0.0
1329+
self._parameter_names_to_weights = {}
13461330

13471331
# Parse expression and validate type constraint (reuses set_expression_str)
13481332
self.set_expression_str(expression_str)

ax/core/search_space.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Mapping, Sequence
1414
from dataclasses import dataclass, field
1515
from logging import Logger
16+
from typing import Any
1617

1718
import numpy as np
1819
import pandas as pd
@@ -1079,7 +1080,7 @@ def cast_observation_features(
10791080
parameter should not be in the arm within the search space due to its
10801081
hierarchical structure.
10811082
"""
1082-
full_parameterization_md = {
1083+
full_parameterization_md: dict[str, Any] = {
10831084
Keys.FULL_PARAMETERIZATION: observation_features.parameters.copy()
10841085
}
10851086
obs_feats = observation_features.clone(
@@ -1089,7 +1090,7 @@ def cast_observation_features(
10891090
)
10901091
)
10911092
if not obs_feats.metadata:
1092-
obs_feats.metadata = full_parameterization_md # pyre-ignore[8]
1093+
obs_feats.metadata = full_parameterization_md
10931094
else:
10941095
obs_feats.metadata = {**obs_feats.metadata, **full_parameterization_md}
10951096

ax/core/trial.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525

2626
ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6
2727

28-
# pyre-fixme[5]: Global expression must be annotated.
29-
round_floats_for_logging = partial(
28+
round_floats_for_logging: partial[Any] = partial(
3029
_round_floats_for_logging,
3130
decimal_places=ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES,
3231
)
@@ -115,19 +114,15 @@ def add_arm(
115114
The trial instance.
116115
"""
117116

117+
cand_metadata_by_sig: dict[str, TCandidateMetadata] | None = None
118+
if candidate_metadata is not None:
119+
cand_metadata: TCandidateMetadata = candidate_metadata.copy()
120+
cand_metadata_by_sig = {arm.signature: cand_metadata}
118121
return self.add_generator_run(
119122
generator_run=GeneratorRun(
120123
arms=[arm],
121124
type=GeneratorRunType.MANUAL.name,
122-
# pyre-ignore[6]: In call `GeneratorRun.__init__`, for 3rd parameter
123-
# `candidate_metadata_by_arm_signature`
124-
# expected `Optional[Dict[str, Optional[Dict[str, typing.Any]]]]`
125-
# but got `Optional[Dict[str, Dict[str, typing.Any]]]`
126-
candidate_metadata_by_arm_signature=(
127-
None
128-
if candidate_metadata is None
129-
else {arm.signature: candidate_metadata.copy()}
130-
),
125+
candidate_metadata_by_arm_signature=cand_metadata_by_sig,
131126
)
132127
)
133128

0 commit comments

Comments
 (0)