Skip to content

Commit 68939aa

Browse files
David Eriksson and meta-codesync[bot]
authored and committed
Get rid of some easy pyre-ignore (#4974)
Summary: Pull Request resolved: #4974. See title. Reviewed By: saitcakmak. Differential Revision: D95264008. fbshipit-source-id: d6a3dc2461027b086b3b34985c092350403d79a9
1 parent 1cd0b89 commit 68939aa

File tree

8 files changed

+43
-38
lines changed

8 files changed

+43
-38
lines changed

ax/adapter/adapter_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
3939
from ax.core.parameter_constraint import ParameterConstraint
4040
from ax.core.search_space import SearchSpace, SearchSpaceDigest
41-
from ax.core.types import TBounds, TCandidateMetadata
41+
from ax.core.types import TBounds, TCandidateMetadata, TNumeric
4242
from ax.exceptions.core import DataRequiredError, UserInputError
4343
from ax.generators.torch.botorch_moo_utils import (
4444
get_weighted_mc_objective_and_objective_thresholds,
@@ -143,8 +143,11 @@ def extract_search_space_digest(
143143
else:
144144
categorical_features.append(i)
145145
# at this point we can assume that values are numeric due to transforms
146-
discrete_choices[i] = p.values # pyre-ignore [6]
147-
bounds.append((min(p.values), max(p.values))) # pyre-ignore [6]
146+
numeric_values: list[TNumeric] = [
147+
assert_is_instance_of_tuple(v, (int, float)) for v in p.values
148+
]
149+
discrete_choices[i] = numeric_values
150+
bounds.append((min(numeric_values), max(numeric_values)))
148151
elif isinstance(p, RangeParameter):
149152
if p.log_scale or p.logit_scale:
150153
raise UserInputError(
@@ -154,8 +157,7 @@ def extract_search_space_digest(
154157
)
155158
if p.parameter_type == ParameterType.INT:
156159
ordinal_features.append(i)
157-
d_choices = list(range(int(p.lower), int(p.upper) + 1))
158-
# pyre-ignore [6]
160+
d_choices: list[TNumeric] = list(range(int(p.lower), int(p.upper) + 1))
159161
discrete_choices[i] = d_choices
160162
bounds.append((p.lower, p.upper))
161163
else:

ax/adapter/transforms/choice_encode.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
108108
dependents = None
109109
if p.is_hierarchical:
110110
# The dependents of hierarchical parameters need to be updated to
111-
# reflect the changes by encoding.
112-
dependents = {
111+
# reflect the changes by encoding. The encoded keys are ints,
112+
# which is a subtype of TParamValue.
113+
encoded_dependents: dict[TParamValue, list[str]] = {
113114
encoding[val]: deps for val, deps in p.dependents.items()
114115
}
116+
dependents = encoded_dependents
115117

116118
transformed_parameters[p_name] = ChoiceParameter(
117119
name=p_name,
@@ -130,7 +132,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
130132
# Retain the original sort_values if the parameter is not ordered.
131133
# Ordered numeric parameters are always sorted.
132134
sort_values=p.sort_values if not p.is_ordered else True,
133-
dependents=dependents, # pyre-ignore[6]
135+
dependents=dependents,
134136
)
135137
else:
136138
transformed_parameters[p.name] = p
@@ -153,7 +155,13 @@ def untransform_observation_features(
153155
if p_name in obsf.parameters:
154156
# Rounding & casting to int in case a floating point value was
155157
# generated. This can happen since generation uses float tensors.
156-
pval = int(round(obsf.parameters[p_name])) # pyre-ignore [6]
158+
# The value can be int (from data) or float (from generation).
159+
param_val = obsf.parameters[p_name]
160+
pval = (
161+
int(round(param_val))
162+
if isinstance(param_val, float)
163+
else assert_is_instance(param_val, int)
164+
)
157165
if pval in reverse_transform:
158166
obsf.parameters[p_name] = reverse_transform[pval]
159167
return observation_features

ax/adapter/transforms/int_to_float.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,7 @@ def transform_observation_features(
9999
for obsf in observation_features:
100100
for p_name in self.transform_parameters:
101101
if p_name in obsf.parameters:
102-
# pyre: param is declared to have type `int` but is used
103-
# pyre-fixme[9]: as type `Optional[typing.Union[bool, float, str]]`.
104-
param: int = obsf.parameters[p_name]
102+
param = assert_is_instance(obsf.parameters[p_name], int)
105103
obsf.parameters[p_name] = float(param)
106104
return observation_features
107105

@@ -146,9 +144,7 @@ def untransform_observation_features(
146144
)
147145
if self.rounding == "strict":
148146
for p_name in present_params:
149-
# pyre: param is declared to have type `float` but is used as
150-
# pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`.
151-
param: float = obsf.parameters.get(p_name)
147+
param = assert_is_instance(obsf.parameters.get(p_name), float)
152148
obsf.parameters[p_name] = int(round(param)) # TODO: T41938776
153149
else:
154150
if self.contains_constrained_integer:
@@ -191,8 +187,8 @@ def untransform_observation_features(
191187
# that satisfies the search space bounds, but this candidate may
192188
# not satisfy the parameter constraints.
193189
for p_name in present_params:
194-
param = obsf.parameters.get(p_name)
195-
obsf.parameters[p_name] = int(round(param)) # pyre-ignore
190+
param = assert_is_instance(obsf.parameters.get(p_name), float)
191+
obsf.parameters[p_name] = int(round(param))
196192
else: # Update observation if rounding was successful
197193
for p_name in present_params:
198194
obsf.parameters[p_name] = rounded_parameters[p_name]

ax/adapter/transforms/logit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ax.core.parameter import ParameterType, RangeParameter
1717
from ax.core.search_space import SearchSpace
1818
from ax.generators.types import TConfig
19+
from pyre_extensions import assert_is_instance
1920
from scipy.special import expit, logit
2021

2122
if TYPE_CHECKING:
@@ -58,7 +59,7 @@ def transform_observation_features(
5859
for obsf in observation_features:
5960
for p_name in self.transform_parameters:
6061
if p_name in obsf.parameters:
61-
param: float = obsf.parameters[p_name] # pyre-ignore [9]
62+
param = assert_is_instance(obsf.parameters[p_name], float)
6263
obsf.parameters[p_name] = logit(param).item()
6364
return observation_features
6465

@@ -78,7 +79,7 @@ def untransform_observation_features(
7879
for obsf in observation_features:
7980
for p_name in self.transform_parameters:
8081
if p_name in obsf.parameters:
81-
param: float = obsf.parameters[p_name] # pyre-ignore [9]
82+
param = assert_is_instance(obsf.parameters[p_name], float)
8283
obsf.parameters[p_name] = expit(param).item()
8384
return observation_features
8485

ax/adapter/transforms/metrics_as_task.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from ax.core.observation import Observation, ObservationData, ObservationFeatures
1717
from ax.core.parameter import ChoiceParameter, ParameterType
1818
from ax.core.search_space import SearchSpace
19+
from ax.core.types import TParamValueList
1920
from ax.generators.types import TConfig
21+
from pyre_extensions import assert_is_instance
2022

2123
if TYPE_CHECKING:
2224
# import as module to make sphinx-autodoc-typehints happy
@@ -58,10 +60,11 @@ def __init__(
5860
# Use config to specify metric task map
5961
if "metric_task_map" not in self.config:
6062
raise ValueError("config must specify metric_task_map")
61-
self.metric_task_map: dict[str, list[str]] = self.config[ # pyre-ignore
62-
"metric_task_map"
63-
]
64-
self.task_values: list[str] = list(self.metric_task_map.keys())
63+
self.metric_task_map: dict[str, list[str]] = assert_is_instance(
64+
self.config["metric_task_map"], dict
65+
)
66+
# Explicitly type to match ChoiceParameter.values expected type
67+
self.task_values: TParamValueList = list(self.metric_task_map.keys())
6568
assert "TARGET" not in self.task_values
6669
self.task_values.append("TARGET")
6770

@@ -136,7 +139,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
136139
task_param = ChoiceParameter(
137140
name="METRIC_TASK",
138141
parameter_type=ParameterType.STRING,
139-
values=self.task_values, # pyre-ignore
142+
values=self.task_values,
140143
is_ordered=False,
141144
is_task=True,
142145
sort_values=True,

ax/adapter/transforms/remove_fixed.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from ax.core.search_space import SearchSpace
2424
from ax.generators.types import TConfig
25+
from pyre_extensions import assert_is_instance
2526

2627
if TYPE_CHECKING:
2728
# import as module to make sphinx-autodoc-typehints happy
@@ -59,8 +60,7 @@ def find_adoptable_descendants(
5960
continue
6061
if isinstance(search_space.parameters[child], FixedParameter):
6162
lst_adoptable_descendants += find_adoptable_descendants(
62-
# pyre-ignore[6]: It's a fixed parameter for sure.
63-
search_space.parameters[child],
63+
assert_is_instance(search_space.parameters[child], FixedParameter),
6464
search_space=search_space,
6565
)
6666
else:
@@ -141,8 +141,9 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
141141
search_space.parameters[child], FixedParameter
142142
):
143143
updated_children += find_adoptable_descendants(
144-
# pyre-ignore[6]: It's a fixed parameter for sure.
145-
param=search_space.parameters[child],
144+
param=assert_is_instance(
145+
search_space.parameters[child], FixedParameter
146+
),
146147
search_space=search_space,
147148
)
148149
else:

ax/adapter/transforms/trial_as_task.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ax.generators.types import TConfig
2121
from ax.utils.common.constants import Keys
2222
from ax.utils.common.logger import get_logger
23-
from pyre_extensions import none_throws
23+
from pyre_extensions import assert_is_instance, none_throws
2424

2525
if TYPE_CHECKING:
2626
# import as module to make sphinx-autodoc-typehints happy
@@ -86,10 +86,7 @@ def __init__(
8686
)
8787
# Get trial level map
8888
if "trial_level_map" in self.config:
89-
# pyre-ignore [9]
90-
trial_level_map: dict[str, dict[int | str, int | str]] = self.config[
91-
"trial_level_map"
92-
]
89+
trial_level_map = assert_is_instance(self.config["trial_level_map"], dict)
9390
# Validate
9491
self.trial_level_map: dict[str, dict[int, int | str]] = {}
9592
for _p_name, level_dict in trial_level_map.items():

ax/benchmark/benchmark_result.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ax.utils.common.base import Base
1919
from numpy import nanmean, nanquantile
2020
from pandas import DataFrame
21+
from pyre_extensions import none_throws
2122
from scipy.stats import sem
2223

2324
PERCENTILES = [0.25, 0.5, 0.75]
@@ -188,11 +189,7 @@ def from_benchmark_results(
188189
num_trials_mean = None
189190
if all(res.num_trials is not None for res in results):
190191
num_trials_step_data = zip(
191-
*(
192-
# pyre-ignore[16]: already checked for None above
193-
res.num_trials
194-
for res in results
195-
)
192+
*(none_throws(res.num_trials) for res in results)
196193
)
197194
num_trials_mean = [nanmean(step_vals) for step_vals in num_trials_step_data]
198195

0 commit comments

Comments (0)