Skip to content

Commit 3c68bd3

Browse files
eonofreyfacebook-github-bot
authored andcommitted
Delete checked_cast and replace `checked_cast_(list|dict|to_tuple|optional) (#3230)
Summary: Pull Request resolved: #3230 Make the below replacements: `checked_cast_list` -> `assert_is_instance_list` `checked_cast_dict` -> `assert_is_instance_dict` `checked_cast_to_tuple` -> `assert_is_instance_of_tuple` `checked_cast_optional` -> `assert_is_instance_optional` `_argparse_type_encoder` untouched Reviewed By: danielcohenlive Differential Revision: D67993468 fbshipit-source-id: b5956a6fc9a81a6516d24a762c6e1257c3cb53f4
1 parent 33ae26a commit 3c68bd3

File tree

13 files changed

+114
-112
lines changed

13 files changed

+114
-112
lines changed

ax/modelbridge/generation_strategy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ax.modelbridge.model_spec import FactoryFunctionModelSpec
3535
from ax.modelbridge.transition_criterion import TrialBasedCriterion
3636
from ax.utils.common.logger import _round_floats_for_logging, get_logger
37-
from ax.utils.common.typeutils import checked_cast_list
37+
from ax.utils.common.typeutils import assert_is_instance_list
3838
from pyre_extensions import none_throws
3939

4040
logger: Logger = get_logger(__name__)
@@ -626,7 +626,7 @@ def clone_reset(self) -> GenerationStrategy:
626626
return GenerationStrategy(name=self.name, nodes=cloned_nodes)
627627

628628
return GenerationStrategy(
629-
name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes)
629+
name=self.name, steps=assert_is_instance_list(cloned_nodes, GenerationStep)
630630
)
631631

632632
def _unset_non_persistent_state_fields(self) -> None:

ax/modelbridge/modelbridge_utils.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@
5656
pareto_frontier_evaluator,
5757
)
5858
from ax.utils.common.logger import get_logger
59-
from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple
59+
from ax.utils.common.typeutils import (
60+
assert_is_instance_of_tuple,
61+
assert_is_instance_optional,
62+
)
6063
from botorch.acquisition.multi_objective.multi_output_risk_measures import (
6164
IndependentCVaR,
6265
IndependentVaR,
@@ -218,7 +221,9 @@ def extract_search_space_digest(
218221
if isinstance(p, ChoiceParameter):
219222
if p.is_task:
220223
task_features.append(i)
221-
target_values[i] = checked_cast_to_tuple((int, float), p.target_value)
224+
target_values[i] = assert_is_instance_of_tuple(
225+
p.target_value, (int, float)
226+
)
222227
elif p.is_ordered:
223228
ordinal_features.append(i)
224229
else:
@@ -243,7 +248,7 @@ def extract_search_space_digest(
243248
raise ValueError(f"Unknown parameter type {type(p)}")
244249
if p.is_fidelity:
245250
fidelity_features.append(i)
246-
target_values[i] = checked_cast_to_tuple((int, float), p.target_value)
251+
target_values[i] = assert_is_instance_of_tuple(p.target_value, (int, float))
247252

248253
return SearchSpaceDigest(
249254
feature_names=param_names,
@@ -1054,8 +1059,8 @@ def _get_multiobjective_optimization_config(
10541059
objective_thresholds: TRefPoint | None = None,
10551060
) -> MultiObjectiveOptimizationConfig:
10561061
# Optimization_config
1057-
mooc = optimization_config or checked_cast_optional(
1058-
MultiObjectiveOptimizationConfig, modelbridge._optimization_config
1062+
mooc = optimization_config or assert_is_instance_optional(
1063+
modelbridge._optimization_config, MultiObjectiveOptimizationConfig
10591064
)
10601065
if not mooc:
10611066
raise ValueError(

ax/modelbridge/transforms/power_transform_y.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
2323
from ax.models.types import TConfig
2424
from ax.utils.common.logger import get_logger
25-
from ax.utils.common.typeutils import checked_cast_list
25+
from ax.utils.common.typeutils import assert_is_instance_list
2626
from pyre_extensions import assert_is_instance
2727
from sklearn.preprocessing import PowerTransformer
2828

@@ -216,5 +216,5 @@ def _compute_inverse_bounds(
216216
bounds[1] = (-1.0 / lambda_ - mu) / sigma
217217
elif lambda_ > 2.0 + tol:
218218
bounds[0] = (1.0 / (2.0 - lambda_) - mu) / sigma
219-
inv_bounds[k] = tuple(checked_cast_list(float, bounds))
219+
inv_bounds[k] = tuple(assert_is_instance_list(bounds, float))
220220
return inv_bounds

ax/models/random/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ax.models.types import TConfig
2525
from ax.utils.common.docutils import copy_doc
2626
from ax.utils.common.logger import get_logger
27-
from ax.utils.common.typeutils import checked_cast_to_tuple
27+
from ax.utils.common.typeutils import assert_is_instance_of_tuple
2828
from botorch.utils.sampling import HitAndRunPolytopeSampler
2929
from pyre_extensions import assert_is_instance
3030
from torch import Tensor
@@ -129,7 +129,7 @@ def gen(
129129
if model_gen_options:
130130
max_draws = model_gen_options.get("max_rs_draws")
131131
if max_draws is not None:
132-
max_draws = int(checked_cast_to_tuple((int, float), max_draws))
132+
max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float)))
133133
try:
134134
# Always rejection sample, but this only rejects if there are
135135
# constraints or actual duplicates and deduplicate is specified.

ax/models/torch/botorch_modular/surrogate.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@
5151
from ax.utils.common.base import Base
5252
from ax.utils.common.constants import Keys
5353
from ax.utils.common.logger import get_logger
54-
from ax.utils.common.typeutils import _argparse_type_encoder, checked_cast_optional
54+
from ax.utils.common.typeutils import (
55+
_argparse_type_encoder,
56+
assert_is_instance_optional,
57+
)
5558
from ax.utils.stats.model_fit_stats import (
5659
DIAGNOSTIC_FN_DIRECTIONS,
5760
DIAGNOSTIC_FNS,
@@ -1277,7 +1280,9 @@ def best_out_of_sample_point(
12771280
options = options or {}
12781281
acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class(
12791282
outcome_constraints=torch_opt_config.outcome_constraints,
1280-
seed_inner=checked_cast_optional(int, options.get(Keys.SEED_INNER, None)),
1283+
seed_inner=assert_is_instance_optional(
1284+
options.get(Keys.SEED_INNER, None), int
1285+
),
12811286
qmc=assert_is_instance(
12821287
options.get(Keys.QMC, True),
12831288
bool,

ax/plot/scatter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
TNullableGeneratorRunsDict,
4545
)
4646
from ax.utils.common.logger import get_logger
47-
from ax.utils.common.typeutils import checked_cast_optional
47+
from ax.utils.common.typeutils import assert_is_instance_optional
4848
from ax.utils.stats.statstools import relativize
4949
from plotly import subplots
5050

@@ -419,7 +419,7 @@ def plot_multiple_metrics(
419419
layout_offset_x = 0.15
420420
else:
421421
layout_offset_x = 0
422-
rel = checked_cast_optional(bool, kwargs.get("rel"))
422+
rel = assert_is_instance_optional(kwargs.get("rel"), bool)
423423
if rel is not None:
424424
warnings.warn(
425425
"Use `rel_x` and `rel_y` instead of `rel`.",

ax/service/tests/test_instantiation_utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,10 @@ def test_choice_with_is_sorted(self) -> None:
385385
else:
386386
self.assertEqual(output.sort_values, sort_values)
387387

388-
with self.assertRaisesRegex(ValueError, "Value was not of type <class 'bool'>"):
388+
with self.assertRaisesRegex(
389+
TypeError,
390+
r"obj is not an instance of cls: obj=\['Foo'\] cls=<class 'bool'>",
391+
):
389392
representation: dict[str, Any] = {
390393
"name": "foo_or_bar",
391394
"type": "choice",

ax/service/utils/instantiation.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
from ax.exceptions.core import UnsupportedError
4848
from ax.utils.common.constants import Keys
4949
from ax.utils.common.logger import get_logger
50-
from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple
50+
from ax.utils.common.typeutils import (
51+
assert_is_instance_of_tuple,
52+
assert_is_instance_optional,
53+
)
5154
from pyre_extensions import assert_is_instance, none_throws
5255

5356
DEFAULT_OBJECTIVE_NAME = "objective"
@@ -227,8 +230,8 @@ def _make_range_param(
227230
parameter_type=cls._to_parameter_type(
228231
bounds, parameter_type, name, "bounds"
229232
),
230-
lower=checked_cast_to_tuple((float, int), bounds[0]),
231-
upper=checked_cast_to_tuple((float, int), bounds[1]),
233+
lower=assert_is_instance_of_tuple(bounds[0], (float, int)),
234+
upper=assert_is_instance_of_tuple(bounds[1], (float, int)),
232235
log_scale=assert_is_instance(representation.get("log_scale", False), bool),
233236
digits=representation.get("digits", None), # pyre-ignore[6]
234237
is_fidelity=assert_is_instance(
@@ -258,17 +261,19 @@ def _make_choice_param(
258261
values, parameter_type, name, "values"
259262
),
260263
values=values,
261-
is_ordered=checked_cast_optional(bool, representation.get("is_ordered")),
264+
is_ordered=assert_is_instance_optional(
265+
representation.get("is_ordered"), bool
266+
),
262267
is_fidelity=assert_is_instance(
263268
representation.get("is_fidelity", False), bool
264269
),
265270
is_task=assert_is_instance(representation.get("is_task", False), bool),
266271
target_value=representation.get("target_value", None), # pyre-ignore[6]
267-
sort_values=checked_cast_optional(
268-
bool, representation.get("sort_values", None)
272+
sort_values=assert_is_instance_optional(
273+
representation.get("sort_values", None), bool
269274
),
270-
dependents=checked_cast_optional(
271-
dict, representation.get("dependents", None)
275+
dependents=assert_is_instance_optional(
276+
representation.get("dependents", None), dict
272277
),
273278
)
274279

ax/utils/common/tests/test_typeutils.py

+25-32
Original file line numberDiff line numberDiff line change
@@ -10,43 +10,36 @@
1010
import numpy as np
1111
from ax.utils.common.testutils import TestCase
1212
from ax.utils.common.typeutils import (
13-
checked_cast,
14-
checked_cast_dict,
15-
checked_cast_list,
16-
checked_cast_optional,
13+
assert_is_instance_dict,
14+
assert_is_instance_list,
15+
assert_is_instance_optional,
1716
)
1817
from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type
18+
from pyre_extensions import assert_is_instance
1919

2020

2121
class TestTypeUtils(TestCase):
22-
def test_checked_cast(self) -> None:
23-
self.assertEqual(checked_cast(float, 2.0), 2.0)
24-
with self.assertRaises(ValueError):
25-
checked_cast(float, 2)
26-
27-
def test_checked_cast_with_error_override(self) -> None:
28-
self.assertEqual(checked_cast(float, 2.0), 2.0)
29-
with self.assertRaises(NotImplementedError):
30-
checked_cast(
31-
float, 2, exception=NotImplementedError("foo() doesn't support ints")
32-
)
33-
34-
def test_checked_cast_list(self) -> None:
35-
self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0])
36-
with self.assertRaises(ValueError):
37-
checked_cast_list(float, [1.0, 2])
38-
39-
def test_checked_cast_optional(self) -> None:
40-
self.assertEqual(checked_cast_optional(float, None), None)
41-
with self.assertRaises(ValueError):
42-
checked_cast_optional(float, 2)
43-
44-
def test_checked_cast_dict(self) -> None:
45-
self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1})
46-
with self.assertRaises(ValueError):
47-
checked_cast_dict(str, int, {"some": 1.0})
48-
with self.assertRaises(ValueError):
49-
checked_cast_dict(str, int, {1: 1})
22+
def test_assert_is_instance(self) -> None:
23+
self.assertEqual(assert_is_instance(2.0, float), 2.0)
24+
with self.assertRaises(TypeError):
25+
assert_is_instance(2, float)
26+
27+
def test_assert_is_instance_list(self) -> None:
28+
self.assertEqual(assert_is_instance_list([1.0, 2.0], float), [1.0, 2.0])
29+
with self.assertRaises(TypeError):
30+
assert_is_instance_list([1.0, 2], float)
31+
32+
def test_assert_is_instance_optional(self) -> None:
33+
self.assertEqual(assert_is_instance_optional(None, float), None)
34+
with self.assertRaises(TypeError):
35+
assert_is_instance_optional(2, float)
36+
37+
def test_assert_is_instance_dict(self) -> None:
38+
self.assertEqual(assert_is_instance_dict({"some": 1}, str, int), {"some": 1})
39+
with self.assertRaises(TypeError):
40+
assert_is_instance_dict({"some": 1.0}, str, int)
41+
with self.assertRaises(TypeError):
42+
assert_is_instance_dict({1: 1}, str, int)
5043

5144
def test_numpy_type_to_python_type(self) -> None:
5245
self.assertEqual(type(numpy_type_to_python_type(np.int64(2))), int)

ax/utils/common/typeutils.py

+39-48
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from typing import Any, TypeVar
99

10+
from pyre_extensions import assert_is_instance
1011

1112
T = TypeVar("T")
1213
V = TypeVar("V")
@@ -15,79 +16,69 @@
1516
Y = TypeVar("Y")
1617

1718

18-
def checked_cast(typ: type[T], val: V, exception: Exception | None = None) -> T:
19+
def assert_is_instance_optional(val: V | None, typ: type[T]) -> T | None:
1920
"""
20-
Cast a value to a type (with a runtime safety check).
21-
22-
Returns the value unchanged and checks its type at runtime. This signals to the
23-
typechecker that the value has the designated type.
24-
25-
Like `typing.cast`_ ``check_cast`` performs no runtime conversion on its argument,
26-
but, unlike ``typing.cast``, ``checked_cast`` will throw an error if the value is
27-
not of the expected type. The type passed as an argument should be a python class.
21+
Asserts that the value is an instance of the given type if it is not None.
2822
2923
Args:
30-
typ: the type to cast to
31-
val: the value that we are casting
32-
exception: override exception to raise if typecheck fails
24+
val: the value to check
25+
typ: the type to check against
3326
Returns:
34-
the ``val`` argument, unchanged
35-
36-
.. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast
27+
the `val` argument, unchanged
3728
"""
38-
if not isinstance(val, typ):
39-
raise (
40-
exception
41-
if exception is not None
42-
else ValueError(f"Value was not of type {typ}:\n{val}")
43-
)
44-
return val
45-
46-
47-
def checked_cast_optional(typ: type[T], val: V | None) -> T | None:
48-
"""Calls checked_cast only if value is not None."""
4929
if val is None:
5030
return val
51-
return checked_cast(typ, val)
31+
return assert_is_instance(val, typ)
5232

5333

54-
def checked_cast_list(typ: type[T], old_l: list[V]) -> list[T]:
55-
"""Calls checked_cast on all items in a list."""
56-
new_l = []
57-
for val in old_l:
58-
val = checked_cast(typ, val)
59-
new_l.append(val)
60-
return new_l
34+
def assert_is_instance_list(old_l: list[V], typ: type[T]) -> list[T]:
35+
"""
36+
Asserts that all items in a list are instances of the given type.
6137
38+
Args:
39+
old_l: the list to check
40+
typ: the type to check against
41+
Returns:
42+
the `old_l` argument, unchanged
43+
"""
44+
return [assert_is_instance(val, typ) for val in old_l]
6245

63-
def checked_cast_dict(
64-
key_typ: type[K], value_typ: type[V], d: dict[X, Y]
46+
47+
def assert_is_instance_dict(
48+
d: dict[X, Y], key_type: type[K], val_type: type[V]
6549
) -> dict[K, V]:
66-
"""Calls checked_cast on all keys and values in the dictionary."""
50+
"""
51+
Asserts that all keys and values in the dictionary are instances
52+
of the given classes.
53+
54+
Args:
55+
d: the dictionary to check
56+
key_type: the type to check against for keys
57+
val_type: the type to check against for values
58+
Returns:
59+
the `d` argument, unchanged
60+
"""
6761
new_dict = {}
6862
for key, val in d.items():
69-
val = checked_cast(value_typ, val)
70-
key = checked_cast(key_typ, key)
63+
key = assert_is_instance(key, key_type)
64+
val = assert_is_instance(val, val_type)
7165
new_dict[key] = val
7266
return new_dict
7367

7468

7569
# pyre-fixme[34]: `T` isn't present in the function's parameters.
76-
def checked_cast_to_tuple(typ: tuple[type[V], ...], val: V) -> T:
70+
def assert_is_instance_of_tuple(val: V, typ: tuple[type[V], ...]) -> T:
7771
"""
78-
Cast a value to a union of multiple types (with a runtime safety check).
79-
This function is similar to `checked_cast`, but allows for the type to be
80-
defined as a tuple of types, in which case the value is cast as a union of
81-
the types in the tuple.
72+
Asserts that a value is an instance of any type in a tuple of types.
8273
8374
Args:
84-
typ: the tuple of types to cast to
85-
val: the value that we are casting
75+
typ: the tuple of types to check against
76+
val: the value that we are checking
8677
Returns:
87-
the ``val`` argument, unchanged
78+
the `val` argument, unchanged
8879
"""
8980
if not isinstance(val, typ):
90-
raise ValueError(f"Value was not of type {type!r}:\n{val!r}")
81+
raise TypeError(f"Value was not of any type {typ!r}:\n{val!r}")
9182
# pyre-fixme[7]: Expected `T` but got `V`.
9283
return val
9384

0 commit comments

Comments
 (0)