Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,6 @@ def _criterion_from_json(
for key, value in object_json.items()
}
init_args = extract_init_args(args=decoded, class_=criterion_class)
# pyre-ignore[45]: Class passed is always a concrete subclass.
return criterion_class(**init_args)


Expand Down
10 changes: 6 additions & 4 deletions ax/storage/json_store/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,9 @@ def choice_parameter_from_json(
# JSON converts dictionary keys to strings. We need to convert them back.
if dependents is not None:
dependents = {
# pyre-ignore [6]: JSON keys are always strings
string_to_parameter_value(s=key, parameter_type=parameter_type): value
string_to_parameter_value(
s=assert_is_instance(key, str), parameter_type=parameter_type
): value
for key, value in dependents.items()
}

Expand Down Expand Up @@ -499,8 +500,9 @@ def fixed_parameter_from_json(
# JSON converts dictionary keys to strings. We need to convert them back.
if dependents is not None:
dependents = {
# pyre-ignore [6]: JSON keys are always strings
string_to_parameter_value(s=key, parameter_type=parameter_type): value
string_to_parameter_value(
s=assert_is_instance(key, str), parameter_type=parameter_type
): value
for key, value in dependents.items()
}

Expand Down
6 changes: 2 additions & 4 deletions ax/storage/json_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@

def object_to_json(
obj: Any,
# pyre-ignore[24]: Missing parameter annotation, Invalid type parameters
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
# pyre-ignore[24]: Missing parameter annotation, Invalid type parameters
class_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_CLASS_ENCODER_REGISTRY,
) -> Any:
"""Convert an Ax object to a JSON-serializable dictionary.
Expand Down
22 changes: 13 additions & 9 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import warnings
from pathlib import Path
from typing import Any
from typing import Any, cast

from ax.adapter.transforms.base import Transform
from ax.core import Experiment, ObservationFeatures
Expand Down Expand Up @@ -74,6 +74,7 @@
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
from botorch.sampling.base import MCSampler
from botorch.utils.types import _DefaultType
from pyre_extensions import assert_is_instance
from torch import Tensor


Expand Down Expand Up @@ -397,10 +398,13 @@ def transform_type_to_dict(transform_type: type[Transform]) -> dict[str, Any]:


def generation_step_to_dict(generation_step: GenerationStep) -> dict[str, Any]:
"""Converts Ax generation step to a dictionary."""
# pyre-fixme[6]: Currently, Pyre doesn't recognize that `Generation
# Step.__new__` actually returns a `GenerationNode`.
return generation_node_to_dict(generation_node=generation_step)
"""Converts Ax generation step to a dictionary.

Note: ``GenerationStep.__new__`` actually returns a ``GenerationNode``.
"""
return generation_node_to_dict(
generation_node=cast(GenerationNode, generation_step)
)


def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]:
Expand Down Expand Up @@ -582,14 +586,14 @@ def botorch_input_transform_to_init_args(
if isinstance(input_transform, ChainedInputTransform):
return {k: botorch_component_to_dict(v) for k, v in input_transform.items()}
else:
try:
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return input_transform.get_init_args()
except AttributeError:
if not hasattr(input_transform, "get_init_args"):
raise JSONEncodeError(
f"{input_transform.__class__.__name__} does not define `get_init_args` "
"method. Please implement it to enable storage."
)
# pyre-fixme[29]: `Union[Tensor, Module]` is not callable; hasattr guards
# this but pyre can't narrow the Union type.
return assert_is_instance(input_transform, InputTransform).get_init_args()


def percentile_early_stopping_strategy_to_dict(
Expand Down
8 changes: 2 additions & 6 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior


# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
CORE_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = {
Arm: arm_to_dict,
AuxiliaryExperiment: auxiliary_experiment_to_dict,
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
Expand Down Expand Up @@ -269,9 +267,7 @@
# NOTE: Avoid putting a class along with its subclass in `CLASS_ENCODER_REGISTRY`.
# The encoder iterates through this dictionary and uses the first superclass that
# it finds, which might not be the intended superclass.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
CORE_CLASS_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
CORE_CLASS_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = {
Acquisition: botorch_modular_to_dict, # Ax MBM component
AcquisitionFunction: botorch_modular_to_dict, # BoTorch component
InputTransform: botorch_modular_to_dict, # BoTorch input transform component
Expand Down
8 changes: 2 additions & 6 deletions ax/storage/json_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@
def save_experiment(
experiment: Experiment,
filepath: str,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
class_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_CLASS_ENCODER_REGISTRY,
) -> None:
"""Save experiment to file.
Expand Down
8 changes: 2 additions & 6 deletions ax/storage/metric_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,13 @@

def register_metrics(
metric_clss: dict[type[Metric], int | None],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
) -> tuple[
dict[type[Metric], int],
dict[type, Callable[[Any], dict[str, Any]]],
dict[type[Any], Callable[[Any], dict[str, Any]]],
TDecoderRegistry,
]:
"""Add custom metric classes to the SQA and JSON registries.
Expand Down
34 changes: 12 additions & 22 deletions ax/storage/registry_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,8 @@ def __init__(
self,
metric_clss: dict[type[Metric], int | None],
runner_clss: dict[type[Runner], int | None],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]],
json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]],
json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]],
json_decoder_registry: TDecoderRegistry,
json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]],
) -> None:
Expand All @@ -79,18 +75,18 @@ def __init__(
runner_clss = {
k: int(v) if v is not None else None for k, v in runner_clss.items()
}
# pyre-fixme[4]: Attribute must be annotated.
self._metric_registry: dict[type[Metric], int]
self._runner_registry: dict[type[Runner], int]
self._encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]]
self._decoder_registry: TDecoderRegistry
self._metric_registry, encoder_registry, decoder_registry = register_metrics(
metric_clss=metric_clss,
encoder_registry=json_encoder_registry,
decoder_registry=json_decoder_registry,
)
(
# pyre-fixme[4]: Attribute must be annotated.
self._runner_registry,
# pyre-fixme[4]: Attribute must be annotated.
self._encoder_registry,
# pyre-fixme[4]: Attribute must be annotated.
self._decoder_registry,
) = register_runners(
runner_clss=runner_clss,
Expand All @@ -110,19 +106,17 @@ def runner_registry(self) -> dict[type[Runner], int]:
return self._runner_registry

@property
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
def encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]:
def encoder_registry(self) -> dict[type[Any], Callable[[Any], dict[str, Any]]]:
return self._encoder_registry

@property
def decoder_registry(self) -> TDecoderRegistry:
return self._decoder_registry

@property
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
def class_encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]:
def class_encoder_registry(
self,
) -> dict[type[Any], Callable[[Any], dict[str, Any]]]:
return self._json_class_encoder_registry

@property
Expand Down Expand Up @@ -177,15 +171,11 @@ def __init__(
self,
metric_clss: dict[type[Metric], int | None],
runner_clss: dict[type[Runner], int | None],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_class_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_CLASS_ENCODER_REGISTRY,
json_decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
json_class_decoder_registry: dict[
Expand Down
20 changes: 4 additions & 16 deletions ax/storage/runner_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,17 @@
CORE_RUNNER_REGISTRY: dict[type[Runner], int] = {SyntheticRunner: 0}


# pyre-fixme[3]: Return annotation cannot contain `Any`.
def register_runner(
runner_cls: type[Runner],
runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
val: int | None = None,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
) -> tuple[
dict[type[Runner], int],
dict[type, Callable[[Any], dict[str, Any]]],
dict[type[Any], Callable[[Any], dict[str, Any]]],
TDecoderRegistry,
]:
"""Add a custom runner class to the SQA and JSON registries.
Expand All @@ -62,22 +56,16 @@ def register_runner(
return new_runner_registry, new_encoder_registry, new_decoder_registry


# pyre-fixme[3]: Return annotation cannot contain `Any`.
def register_runners(
runner_clss: dict[type[Runner], int | None],
runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
type[Any], Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
) -> tuple[
dict[type[Runner], int],
dict[type, Callable[[Any], dict[str, Any]]],
dict[type[Any], Callable[[Any], dict[str, Any]]],
TDecoderRegistry,
]:
"""Add custom runner classes to the SQA and JSON registries.
Expand Down
3 changes: 1 addition & 2 deletions ax/storage/sqa_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LONGTEXT_BYTES: int = 2**32 - 1

# global database variables
SESSION_FACTORY: Session | None = None
SESSION_FACTORY: scoped_session | None = None

# set this to false to prevent SQLAlchemy for automatically expiring objects
# on commit, which essentially makes them unusable outside of a session
Expand Down Expand Up @@ -235,7 +235,6 @@ def get_session() -> Session:
if SESSION_FACTORY is None:
init_engine_and_session_factory()
assert SESSION_FACTORY is not None
# pyre-fixme[29]: `Session` is not a function.
return SESSION_FACTORY()


Expand Down
Loading