Skip to content

Commit 6560861

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/storage/ source files (facebook#4988)
Summary: Pull Request resolved: facebook#4988 Remove ~124 pyre-fixme/pyre-ignore suppression comments from 22 source files in ax/storage/ by applying proper type fixes: - Use `cast(type[SQAClass], ...)` for SQA class lookups from config dicts - Use `cast(type[Enum], enum)` for enum value/name access - Change bare `type` to `type[Any]` in registry function signatures - Use `assert_is_instance()` for JSON dict key narrowing - Add proper type annotations for Generator return types - Use `none_throws()` for generation strategy ID access - Fix SQLAlchemy TypeDecorator parameter types Remaining pyre errors are pre-existing SQLAlchemy/BoTorch stub mismatches that cannot be fixed without changing library type stubs. Reviewed By: dme65 Differential Revision: D95264795 fbshipit-source-id: 29cdfa255b929dfaeaf877b88e34850a8a343f03
1 parent eb264f1 commit 6560861

File tree

19 files changed

+131
-183
lines changed

19 files changed

+131
-183
lines changed

ax/storage/json_store/decoder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ def _criterion_from_json(
474474
for key, value in object_json.items()
475475
}
476476
init_args = extract_init_args(args=decoded, class_=criterion_class)
477-
# pyre-ignore[45]: Class passed is always a concrete subclass.
478477
return criterion_class(**init_args)
479478

480479

ax/storage/json_store/decoders.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,9 @@ def choice_parameter_from_json(
454454
# JSON converts dictionary keys to strings. We need to convert them back.
455455
if dependents is not None:
456456
dependents = {
457-
# pyre-ignore [6]: JSON keys are always strings
458-
string_to_parameter_value(s=key, parameter_type=parameter_type): value
457+
string_to_parameter_value(
458+
s=assert_is_instance(key, str), parameter_type=parameter_type
459+
): value
459460
for key, value in dependents.items()
460461
}
461462

@@ -499,8 +500,9 @@ def fixed_parameter_from_json(
499500
# JSON converts dictionary keys to strings. We need to convert them back.
500501
if dependents is not None:
501502
dependents = {
502-
# pyre-ignore [6]: JSON keys are always strings
503-
string_to_parameter_value(s=key, parameter_type=parameter_type): value
503+
string_to_parameter_value(
504+
s=assert_is_instance(key, str), parameter_type=parameter_type
505+
): value
504506
for key, value in dependents.items()
505507
}
506508

ax/storage/json_store/encoder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@
3131

3232
def object_to_json(
3333
obj: Any,
34-
# pyre-ignore[24]: Missing parameter annotation, Invalid type parameters
3534
encoder_registry: dict[
36-
type, Callable[[Any], dict[str, Any]]
35+
type[Any], Callable[[Any], dict[str, Any]]
3736
] = CORE_ENCODER_REGISTRY,
38-
# pyre-ignore[24]: Missing parameter annotation, Invalid type parameters
3937
class_encoder_registry: dict[
40-
type, Callable[[Any], dict[str, Any]]
38+
type[Any], Callable[[Any], dict[str, Any]]
4139
] = CORE_CLASS_ENCODER_REGISTRY,
4240
) -> Any:
4341
"""Convert an Ax object to a JSON-serializable dictionary.

ax/storage/json_store/encoders.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import warnings
1010
from pathlib import Path
11-
from typing import Any
11+
from typing import Any, cast
1212

1313
from ax.adapter.transforms.base import Transform
1414
from ax.core import Experiment, ObservationFeatures
@@ -74,6 +74,7 @@
7474
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
7575
from botorch.sampling.base import MCSampler
7676
from botorch.utils.types import _DefaultType
77+
from pyre_extensions import assert_is_instance
7778
from torch import Tensor
7879

7980

@@ -397,10 +398,13 @@ def transform_type_to_dict(transform_type: type[Transform]) -> dict[str, Any]:
397398

398399

399400
def generation_step_to_dict(generation_step: GenerationStep) -> dict[str, Any]:
400-
"""Converts Ax generation step to a dictionary."""
401-
# pyre-fixme[6]: Currently, Pyre doesn't recognize that `Generation
402-
# Step.__new__` actually returns a `GenerationNode`.
403-
return generation_node_to_dict(generation_node=generation_step)
401+
"""Converts Ax generation step to a dictionary.
402+
403+
Note: ``GenerationStep.__new__`` actually returns a ``GenerationNode``.
404+
"""
405+
return generation_node_to_dict(
406+
generation_node=cast(GenerationNode, generation_step)
407+
)
404408

405409

406410
def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]:
@@ -582,14 +586,14 @@ def botorch_input_transform_to_init_args(
582586
if isinstance(input_transform, ChainedInputTransform):
583587
return {k: botorch_component_to_dict(v) for k, v in input_transform.items()}
584588
else:
585-
try:
586-
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
587-
return input_transform.get_init_args()
588-
except AttributeError:
589+
if not hasattr(input_transform, "get_init_args"):
589590
raise JSONEncodeError(
590591
f"{input_transform.__class__.__name__} does not define `get_init_args` "
591592
"method. Please implement it to enable storage."
592593
)
594+
# pyre-fixme[29]: `Union[Tensor, Module]` is not callable; hasattr guards
595+
# this but pyre can't narrow the Union type.
596+
return assert_is_instance(input_transform, InputTransform).get_init_args()
593597

594598

595599
def percentile_early_stopping_strategy_to_dict(

ax/storage/json_store/registry.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,7 @@
184184
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior
185185

186186

187-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
188-
# avoid runtime subscripting errors.
189-
CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
187+
CORE_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = {
190188
Arm: arm_to_dict,
191189
AuxiliaryExperiment: auxiliary_experiment_to_dict,
192190
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
@@ -269,9 +267,7 @@
269267
# NOTE: Avoid putting a class along with its subclass in `CLASS_ENCODER_REGISTRY`.
270268
# The encoder iterates through this dictionary and uses the first superclass that
271269
# it finds, which might not be the intended superclass.
272-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
273-
# avoid runtime subscripting errors.
274-
CORE_CLASS_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
270+
CORE_CLASS_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = {
275271
Acquisition: botorch_modular_to_dict, # Ax MBM component
276272
AcquisitionFunction: botorch_modular_to_dict, # BoTorch component
277273
InputTransform: botorch_modular_to_dict, # BoTorch input transform component

ax/storage/json_store/save.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,11 @@
2121
def save_experiment(
2222
experiment: Experiment,
2323
filepath: str,
24-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
25-
# `typing.Type` to avoid runtime subscripting errors.
2624
encoder_registry: dict[
27-
type, Callable[[Any], dict[str, Any]]
25+
type[Any], Callable[[Any], dict[str, Any]]
2826
] = CORE_ENCODER_REGISTRY,
29-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
30-
# `typing.Type` to avoid runtime subscripting errors.
3127
class_encoder_registry: dict[
32-
type, Callable[[Any], dict[str, Any]]
28+
type[Any], Callable[[Any], dict[str, Any]]
3329
] = CORE_CLASS_ENCODER_REGISTRY,
3430
) -> None:
3531
"""Save experiment to file.

ax/storage/metric_registry.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,13 @@
5050

5151
def register_metrics(
5252
metric_clss: dict[type[Metric], int | None],
53-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
54-
# `typing.Type` to avoid runtime subscripting errors.
5553
encoder_registry: dict[
56-
type, Callable[[Any], dict[str, Any]]
54+
type[Any], Callable[[Any], dict[str, Any]]
5755
] = CORE_ENCODER_REGISTRY,
5856
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
59-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
60-
# avoid runtime subscripting errors.
6157
) -> tuple[
6258
dict[type[Metric], int],
63-
dict[type, Callable[[Any], dict[str, Any]]],
59+
dict[type[Any], Callable[[Any], dict[str, Any]]],
6460
TDecoderRegistry,
6561
]:
6662
"""Add custom metric classes to the SQA and JSON registries.

ax/storage/registry_bundle.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,8 @@ def __init__(
6464
self,
6565
metric_clss: dict[type[Metric], int | None],
6666
runner_clss: dict[type[Runner], int | None],
67-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
68-
# `typing.Type` to avoid runtime subscripting errors.
69-
json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]],
70-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
71-
# `typing.Type` to avoid runtime subscripting errors.
72-
json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]],
67+
json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]],
68+
json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]],
7369
json_decoder_registry: TDecoderRegistry,
7470
json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]],
7571
) -> None:
@@ -79,18 +75,18 @@ def __init__(
7975
runner_clss = {
8076
k: int(v) if v is not None else None for k, v in runner_clss.items()
8177
}
82-
# pyre-fixme[4]: Attribute must be annotated.
78+
self._metric_registry: dict[type[Metric], int]
79+
self._runner_registry: dict[type[Runner], int]
80+
self._encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]]
81+
self._decoder_registry: TDecoderRegistry
8382
self._metric_registry, encoder_registry, decoder_registry = register_metrics(
8483
metric_clss=metric_clss,
8584
encoder_registry=json_encoder_registry,
8685
decoder_registry=json_decoder_registry,
8786
)
8887
(
89-
# pyre-fixme[4]: Attribute must be annotated.
9088
self._runner_registry,
91-
# pyre-fixme[4]: Attribute must be annotated.
9289
self._encoder_registry,
93-
# pyre-fixme[4]: Attribute must be annotated.
9490
self._decoder_registry,
9591
) = register_runners(
9692
runner_clss=runner_clss,
@@ -110,19 +106,17 @@ def runner_registry(self) -> dict[type[Runner], int]:
110106
return self._runner_registry
111107

112108
@property
113-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
114-
# `typing.Type` to avoid runtime subscripting errors.
115-
def encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]:
109+
def encoder_registry(self) -> dict[type[Any], Callable[[Any], dict[str, Any]]]:
116110
return self._encoder_registry
117111

118112
@property
119113
def decoder_registry(self) -> TDecoderRegistry:
120114
return self._decoder_registry
121115

122116
@property
123-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
124-
# `typing.Type` to avoid runtime subscripting errors.
125-
def class_encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]:
117+
def class_encoder_registry(
118+
self,
119+
) -> dict[type[Any], Callable[[Any], dict[str, Any]]]:
126120
return self._json_class_encoder_registry
127121

128122
@property
@@ -177,15 +171,11 @@ def __init__(
177171
self,
178172
metric_clss: dict[type[Metric], int | None],
179173
runner_clss: dict[type[Runner], int | None],
180-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
181-
# `typing.Type` to avoid runtime subscripting errors.
182174
json_encoder_registry: dict[
183-
type, Callable[[Any], dict[str, Any]]
175+
type[Any], Callable[[Any], dict[str, Any]]
184176
] = CORE_ENCODER_REGISTRY,
185-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
186-
# `typing.Type` to avoid runtime subscripting errors.
187177
json_class_encoder_registry: dict[
188-
type, Callable[[Any], dict[str, Any]]
178+
type[Any], Callable[[Any], dict[str, Any]]
189179
] = CORE_CLASS_ENCODER_REGISTRY,
190180
json_decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
191181
json_class_decoder_registry: dict[

ax/storage/runner_registry.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,17 @@
3131
CORE_RUNNER_REGISTRY: dict[type[Runner], int] = {SyntheticRunner: 0}
3232

3333

34-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
3534
def register_runner(
3635
runner_cls: type[Runner],
3736
runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY,
38-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
39-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
40-
# `typing.Type` to avoid runtime subscripting errors.
4137
encoder_registry: dict[
42-
type, Callable[[Any], dict[str, Any]]
38+
type[Any], Callable[[Any], dict[str, Any]]
4339
] = CORE_ENCODER_REGISTRY,
4440
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
4541
val: int | None = None,
46-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
47-
# avoid runtime subscripting errors.
4842
) -> tuple[
4943
dict[type[Runner], int],
50-
dict[type, Callable[[Any], dict[str, Any]]],
44+
dict[type[Any], Callable[[Any], dict[str, Any]]],
5145
TDecoderRegistry,
5246
]:
5347
"""Add a custom runner class to the SQA and JSON registries.
@@ -62,22 +56,16 @@ def register_runner(
6256
return new_runner_registry, new_encoder_registry, new_decoder_registry
6357

6458

65-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
6659
def register_runners(
6760
runner_clss: dict[type[Runner], int | None],
6861
runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY,
69-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
70-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
71-
# `typing.Type` to avoid runtime subscripting errors.
7262
encoder_registry: dict[
73-
type, Callable[[Any], dict[str, Any]]
63+
type[Any], Callable[[Any], dict[str, Any]]
7464
] = CORE_ENCODER_REGISTRY,
7565
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
76-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
77-
# avoid runtime subscripting errors.
7866
) -> tuple[
7967
dict[type[Runner], int],
80-
dict[type, Callable[[Any], dict[str, Any]]],
68+
dict[type[Any], Callable[[Any], dict[str, Any]]],
8169
TDecoderRegistry,
8270
]:
8371
"""Add custom runner classes to the SQA and JSON registries.

ax/storage/sqa_store/db.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
LONGTEXT_BYTES: int = 2**32 - 1
3434

3535
# global database variables
36-
SESSION_FACTORY: Session | None = None
36+
SESSION_FACTORY: scoped_session | None = None
3737

3838
# set this to false to prevent SQLAlchemy for automatically expiring objects
3939
# on commit, which essentially makes them unusable outside of a session
@@ -235,7 +235,6 @@ def get_session() -> Session:
235235
if SESSION_FACTORY is None:
236236
init_engine_and_session_factory()
237237
assert SESSION_FACTORY is not None
238-
# pyre-fixme[29]: `Session` is not a function.
239238
return SESSION_FACTORY()
240239

241240

0 commit comments

Comments
 (0)