Skip to content

Commit ca93faa

Browse files
eonofreyfacebook-github-bot
authored andcommitted
Replace checked_cast with pyre_extensions assert_is_instance (#3229)
Summary: Pull Request resolved: #3229 Replace `checked_cast` with pyre_extensions `assert_is_instance` Search for all use cases with "ax.utils.common.typeutils .* checked_cast" regex Reviewed By: esantorella Differential Revision: D67879879 fbshipit-source-id: c009242e71aab2093d96577615b1c7c408440fd0
1 parent d03df05 commit ca93faa

File tree

87 files changed

+474
-440
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+474
-440
lines changed

ax/analysis/healthcheck/constraints_feasibility.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
from ax.modelbridge.base import ModelBridge
2727
from ax.modelbridge.generation_strategy import GenerationStrategy
2828
from ax.modelbridge.transforms.derelativize import Derelativize
29-
from ax.utils.common.typeutils import checked_cast
30-
from pyre_extensions import none_throws
29+
from pyre_extensions import assert_is_instance, none_throws
3130

3231

3332
class ConstraintsFeasibilityAnalysis(HealthcheckAnalysis):
@@ -101,12 +100,8 @@ def compute(
101100
raise UserInputError(
102101
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
103102
)
104-
generation_strategy = checked_cast(
105-
GenerationStrategy,
106-
generation_strategy,
107-
exception=UserInputError(
108-
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
109-
),
103+
generation_strategy = assert_is_instance(
104+
generation_strategy, GenerationStrategy
110105
)
111106

112107
if generation_strategy.model is None:
@@ -120,8 +115,8 @@ def compute(
120115
"The current model is {model._model_key} and does not support "
121116
"prediction."
122117
)
123-
optimization_config = checked_cast(
124-
OptimizationConfig, experiment.optimization_config
118+
optimization_config = assert_is_instance(
119+
experiment.optimization_config, OptimizationConfig
125120
)
126121
constraints_feasible, df = constraints_feasibility(
127122
optimization_config=optimization_config,

ax/analysis/healthcheck/search_space_analysis.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ax.core.search_space import SearchSpace
2626
from ax.core.types import TParameterization
2727
from ax.exceptions.core import UserInputError
28-
from ax.utils.common.typeutils import checked_cast
28+
from pyre_extensions import assert_is_instance
2929

3030

3131
class SearchSpaceAnalysis(HealthcheckAnalysis):
@@ -141,7 +141,9 @@ def search_space_boundary_proportions(
141141
lower = parameter.lower
142142
upper = parameter.upper
143143
elif isinstance(parameter, ChoiceParameter) and parameter.is_ordered:
144-
values = [checked_cast(Union[int, float], v) for v in parameter.values]
144+
values = [
145+
assert_is_instance(v, Union[int, float]) for v in parameter.values
146+
]
145147
lower = min(values)
146148
upper = max(values)
147149
else:
@@ -176,7 +178,8 @@ def search_space_boundary_proportions(
176178
for pc in search_space.parameter_constraints:
177179
weighted_sums = [
178180
sum(
179-
float(checked_cast(Union[int, float], parametrization[param])) * weight
181+
float(assert_is_instance(parametrization[param], Union[int, float]))
182+
* weight
180183
for param, weight in pc.constraint_dict.items()
181184
)
182185
for parametrization in parametrizations

ax/analysis/healthcheck/tests/test_constraints_feasibility.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@
2626
from ax.modelbridge.model_spec import ModelSpec
2727
from ax.modelbridge.registry import Models
2828
from ax.utils.common.testutils import TestCase
29-
from ax.utils.common.typeutils import checked_cast
3029
from ax.utils.testing.core_stubs import get_branin_experiment_with_multi_objective
3130
from ax.utils.testing.mock import mock_botorch_optimize
32-
from pyre_extensions import none_throws
31+
from pyre_extensions import assert_is_instance, none_throws
3332

3433

3534
class TestConstraintsFeasibilityAnalysis(TestCase):
@@ -74,7 +73,7 @@ def setUp(self) -> None:
7473
sobol = get_sobol(search_space=experiment.search_space)
7574
experiment.new_batch_trial(generator_run=sobol.gen(5))
7675

77-
batch_trial = checked_cast(BatchTrial, experiment.trials[0])
76+
batch_trial = assert_is_instance(experiment.trials[0], BatchTrial)
7877

7978
batch_trial.add_arm(experiment.status_quo)
8079
batch_trial.set_status_quo_with_weight(
@@ -107,8 +106,8 @@ def setUp(self) -> None:
107106
def test_constraints_feasibility(self) -> None:
108107
self.setUp()
109108
model = none_throws(self.generation_strategy.model)
110-
optimization_config = checked_cast(
111-
OptimizationConfig, self.experiment.optimization_config
109+
optimization_config = assert_is_instance(
110+
self.experiment.optimization_config, OptimizationConfig
112111
)
113112
constraints_feasible, df_arms = constraints_feasibility(
114113
optimization_config=optimization_config,
@@ -136,8 +135,8 @@ def test_constraints_feasibility(self) -> None:
136135
experiment.attach_data(data=Data(df=df))
137136
generation_strategy._fit_current_model(data=experiment.lookup_data())
138137
model = none_throws(generation_strategy.model)
139-
optimization_config = checked_cast(
140-
OptimizationConfig, experiment.optimization_config
138+
optimization_config = assert_is_instance(
139+
experiment.optimization_config, OptimizationConfig
141140
)
142141
constraints_feasible, df_arms = constraints_feasibility(
143142
optimization_config=optimization_config, model=model
@@ -146,8 +145,8 @@ def test_constraints_feasibility(self) -> None:
146145
experiment.optimization_config = OptimizationConfig(
147146
objective=Objective(metric=Metric(name="branin_a"), minimize=False),
148147
)
149-
optimization_config = checked_cast(
150-
OptimizationConfig, experiment.optimization_config
148+
optimization_config = assert_is_instance(
149+
experiment.optimization_config, OptimizationConfig
151150
)
152151
with self.assertRaises(UserInputError):
153152
constraints_feasibility(

ax/analysis/plotly/arm_effects/predicted_effects.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
from ax.modelbridge.base import ModelBridge
2727
from ax.modelbridge.generation_strategy import GenerationStrategy
2828
from ax.modelbridge.transforms.derelativize import Derelativize
29-
from ax.utils.common.typeutils import checked_cast
30-
from pyre_extensions import none_throws
29+
from pyre_extensions import assert_is_instance, none_throws
3130

3231

3332
class PredictedEffectsPlot(PlotlyAnalysis):
@@ -73,14 +72,14 @@ def compute(
7372
) -> PlotlyAnalysisCard:
7473
if experiment is None:
7574
raise UserInputError("PredictedEffectsPlot requires an Experiment.")
76-
77-
generation_strategy = checked_cast(
78-
GenerationStrategy,
79-
generation_strategy,
80-
exception=UserInputError(
75+
try:
76+
generation_strategy = assert_is_instance(
77+
generation_strategy, GenerationStrategy
78+
)
79+
except TypeError as e:
80+
raise UserInputError(
8181
"PredictedEffectsPlot requires a GenerationStrategy."
82-
),
83-
)
82+
) from e
8483

8584
try:
8685
trial_indices = [

ax/analysis/plotly/tests/test_cross_validation.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from ax.exceptions.core import UserInputError
1212
from ax.service.ax_client import AxClient, ObjectiveProperties
1313
from ax.utils.common.testutils import TestCase
14-
from ax.utils.common.typeutils import checked_cast
1514
from ax.utils.testing.mock import mock_botorch_optimize
16-
from pyre_extensions import none_throws
15+
from pyre_extensions import assert_is_instance, none_throws
1716

1817

1918
class TestCrossValidationPlot(TestCase):
@@ -71,7 +70,7 @@ def test_compute(self) -> None:
7170
# and therefore hasn't observed it
7271
if t.index == max(self.client.experiment.trials.keys()):
7372
continue
74-
arm_name = none_throws(checked_cast(Trial, t).arm).name
73+
arm_name = none_throws(assert_is_instance(t, Trial).arm).name
7574
self.assertIn(
7675
arm_name,
7776
card.df["arm_name"].unique(),
@@ -93,7 +92,7 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
9392
# and therefore hasn't observed it
9493
if t.index == max(self.client.experiment.trials.keys()):
9594
continue
96-
arm_name = none_throws(checked_cast(Trial, t).arm).name
95+
arm_name = none_throws(assert_is_instance(t, Trial).arm).name
9796
self.assertIn(
9897
arm_name,
9998
card.df["arm_name"].unique(),

ax/analysis/plotly/tests/test_predicted_effects.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from ax.modelbridge.prediction_utils import predict_at_point
2020
from ax.modelbridge.registry import Models
2121
from ax.utils.common.testutils import TestCase
22-
from ax.utils.common.typeutils import checked_cast
2322
from ax.utils.testing.core_stubs import (
2423
get_branin_experiment,
2524
get_branin_metric,
@@ -28,7 +27,7 @@
2827
from ax.utils.testing.mock import mock_botorch_optimize
2928
from ax.utils.testing.modeling_stubs import get_sobol_MBM_MTGP_gs
3029
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
31-
from pyre_extensions import none_throws
30+
from pyre_extensions import assert_is_instance, none_throws
3231

3332

3433
class TestPredictedEffectsPlot(TestCase):
@@ -298,7 +297,7 @@ def test_it_works_for_non_batch_experiments(self) -> None:
298297
# THEN it has all arms represented in the dataframe
299298
for trial in experiment.trials.values():
300299
self.assertIn(
301-
none_throws(checked_cast(Trial, trial).arm).name,
300+
none_throws(assert_is_instance(trial, Trial).arm).name,
302301
card.df["arm_name"].unique(),
303302
)
304303

ax/benchmark/tests/test_benchmark_problem.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from ax.core.types import ComparisonOp
2727
from ax.exceptions.core import UserInputError
2828
from ax.utils.common.testutils import TestCase
29-
from ax.utils.common.typeutils import checked_cast
3029
from botorch.test_functions.base import ConstrainedBaseTestProblem
3130
from botorch.test_functions.multi_objective import BraninCurrin, ConstrainedBraninCurrin
3231
from botorch.test_functions.synthetic import (
@@ -225,14 +224,14 @@ def _test_constrained_from_botorch(
225224
)
226225

227226
self.assertEqual(
228-
checked_cast(BenchmarkMetric, metric).observe_noise_sd,
227+
assert_is_instance(metric, BenchmarkMetric).observe_noise_sd,
229228
observe_noise_sd,
230229
)
231230

232231
# TODO: Support observing noise variance only for some outputs
233232
for constraint in outcome_constraints:
234233
self.assertEqual(
235-
checked_cast(BenchmarkMetric, constraint.metric).observe_noise_sd,
234+
assert_is_instance(constraint.metric, BenchmarkMetric).observe_noise_sd,
236235
observe_noise_sd,
237236
)
238237

ax/benchmark/tests/test_benchmark_runner.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from ax.core.trial import Trial
3333
from ax.exceptions.core import UnsupportedError
3434
from ax.utils.common.testutils import TestCase
35-
from ax.utils.common.typeutils import checked_cast
3635
from ax.utils.testing.benchmark_stubs import (
3736
DummyTestFunction,
3837
get_jenatton_trials,
@@ -42,7 +41,7 @@
4241
from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
4342
from botorch.utils.transforms import normalize
4443
from pandas import DataFrame
45-
from pyre_extensions import none_throws
44+
from pyre_extensions import assert_is_instance, none_throws
4645

4746

4847
class TestBenchmarkRunner(TestCase):
@@ -315,7 +314,7 @@ def test_heterogeneous_noise(self) -> None:
315314
noise_std=noise_std,
316315
)
317316
self.assertDictEqual(
318-
checked_cast(dict, runner.get_noise_stds()), noise_dict
317+
assert_is_instance(runner.get_noise_stds(), dict), noise_dict
319318
)
320319

321320
X = torch.rand(1, 6, dtype=torch.double)

ax/core/batch_trial.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
from ax.utils.common.docutils import copy_doc
3535
from ax.utils.common.equality import datetime_equals, equality_typechecker
3636
from ax.utils.common.logger import _round_floats_for_logging, get_logger
37-
from ax.utils.common.typeutils import checked_cast
38-
from pyre_extensions import none_throws
37+
from pyre_extensions import assert_is_instance, none_throws
3938

4039
logger: Logger = get_logger(__name__)
4140

@@ -490,7 +489,10 @@ def is_factorial(self) -> bool:
490489
return len(self.arms) == param_cardinality
491490

492491
def run(self) -> BatchTrial:
493-
return checked_cast(BatchTrial, super().run())
492+
return assert_is_instance(
493+
super().run(),
494+
BatchTrial,
495+
)
494496

495497
def normalized_arm_weights(
496498
self, total: float = 1, trunc_digits: int | None = None

ax/core/data.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
TClassDecoderRegistry,
2828
TDecoderRegistry,
2929
)
30-
from ax.utils.common.typeutils import checked_cast
31-
from pyre_extensions import none_throws
30+
from pyre_extensions import assert_is_instance, none_throws
3231

3332
TBaseData = TypeVar("TBaseData", bound="BaseData")
3433
DF_REPR_MAX_LENGTH = 1000
@@ -148,7 +147,7 @@ def _safecast_df(
148147
and coltype is not Any
149148
}
150149

151-
return checked_cast(pd.DataFrame, df.astype(dtype=dtype))
150+
return assert_is_instance(df.astype(dtype=dtype), pd.DataFrame)
152151

153152
@classmethod
154153
def required_columns(cls) -> set[str]:
@@ -194,7 +193,7 @@ def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
194193
"""Serialize the class-dependent properties needed to initialize this Data.
195194
Used for storage and to help construct new similar Data.
196195
"""
197-
data = checked_cast(cls, obj)
196+
data = assert_is_instance(obj, cls)
198197
return serialize_init_args(
199198
obj=data, exclude_fields=["_skip_ordering_and_validation"]
200199
)

ax/core/experiment.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@
5555
from ax.utils.common.logger import _round_floats_for_logging, get_logger
5656
from ax.utils.common.result import Err, Ok
5757
from ax.utils.common.timeutils import current_timestamp_in_millis
58-
from ax.utils.common.typeutils import checked_cast
59-
from pyre_extensions import none_throws
58+
from pyre_extensions import assert_is_instance, none_throws
6059

6160
logger: logging.Logger = get_logger(__name__)
6261

@@ -1389,8 +1388,8 @@ def warm_start_from_old_experiment(
13891388
old_data = (
13901389
old_experiment.default_data_constructor(
13911390
df=new_df,
1392-
map_key_infos=checked_cast(
1393-
MapData, old_experiment.lookup_data()
1391+
map_key_infos=assert_is_instance(
1392+
old_experiment.lookup_data(), MapData
13941393
).map_key_infos,
13951394
)
13961395
if old_experiment.default_data_type == DataType.MAP_DATA
@@ -1603,7 +1602,7 @@ def attach_trial(
16031602
# data for this arm "complete" in the flattened search space.
16041603
candidate_metadata = None
16051604
if self.search_space.is_hierarchical:
1606-
hss = checked_cast(HierarchicalSearchSpace, self.search_space)
1605+
hss = assert_is_instance(self.search_space, HierarchicalSearchSpace)
16071606
candidate_metadata = hss.cast_observation_features(
16081607
observation_features=hss.flatten_observation_features(
16091608
observation_features=observation.ObservationFeatures(
@@ -1785,7 +1784,9 @@ def metric_config_summary_df(self) -> pd.DataFrame:
17851784
if self.optimization_config is not None:
17861785
opt_config = self.optimization_config
17871786
if self.is_moo_problem:
1788-
multi_objective = checked_cast(MultiObjective, opt_config.objective)
1787+
multi_objective = assert_is_instance(
1788+
opt_config.objective, MultiObjective
1789+
)
17891790
objectives = multi_objective.objectives
17901791
else:
17911792
objectives = [opt_config.objective]

ax/core/map_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
TClassDecoderRegistry,
2929
TDecoderRegistry,
3030
)
31-
from ax.utils.common.typeutils import checked_cast
31+
from pyre_extensions import assert_is_instance
3232

3333
logger: Logger = get_logger(__name__)
3434

@@ -327,7 +327,7 @@ def filter(
327327
@classmethod
328328
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
329329
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
330-
map_data = checked_cast(MapData, obj)
330+
map_data = assert_is_instance(obj, MapData)
331331
properties = serialize_init_args(
332332
obj=map_data, exclude_fields=["_skip_ordering_and_validation"]
333333
)

ax/core/observation.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
from ax.utils.common.base import Base
2929
from ax.utils.common.constants import Keys
3030
from ax.utils.common.logger import get_logger
31-
from ax.utils.common.typeutils import checked_cast
32-
from pyre_extensions import none_throws
31+
from pyre_extensions import assert_is_instance, none_throws
3332

3433
logger: Logger = get_logger(__name__)
3534

@@ -429,7 +428,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
429428
# use observations_from_map_data, which is required
430429
# to properly handle MapData features (e.g. fidelity).
431430
if is_map_data:
432-
data = checked_cast(MapData, data)
431+
data = assert_is_instance(data, MapData)
433432
feature_cols = feature_cols.union(data.map_keys)
434433

435434
for column in TIME_COLS:

0 commit comments

Comments
 (0)