From 8d58d9979e53a13e35c88120d0d7287d347664ec Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 6 Mar 2026 06:50:16 -0800 Subject: [PATCH 1/3] Remove pyre-fixme/pyre-ignore from ax/storage/ source files Summary: Remove ~124 pyre-fixme/pyre-ignore suppression comments from 22 source files in ax/storage/ by applying proper type fixes: - Use `cast(type[SQAClass], ...)` for SQA class lookups from config dicts - Use `cast(type[Enum], enum)` for enum value/name access - Change bare `type` to `type[Any]` in registry function signatures - Use `assert_is_instance()` for JSON dict key narrowing - Add proper type annotations for Generator return types - Use `none_throws()` for generation strategy ID access - Fix SQLAlchemy TypeDecorator parameter types Remaining pyre errors are pre-existing SQLAlchemy/BoTorch stub mismatches that cannot be fixed without changing library type stubs. Differential Revision: D95264795 --- ax/storage/json_store/decoder.py | 1 - ax/storage/json_store/decoders.py | 10 +++--- ax/storage/json_store/encoder.py | 6 ++-- ax/storage/json_store/encoders.py | 22 +++++++----- ax/storage/json_store/registry.py | 8 ++--- ax/storage/json_store/save.py | 8 ++--- ax/storage/metric_registry.py | 8 ++--- ax/storage/registry_bundle.py | 34 +++++++------------ ax/storage/runner_registry.py | 20 +++-------- ax/storage/sqa_store/db.py | 3 +- ax/storage/sqa_store/decoder.py | 48 ++++++++++----------------- ax/storage/sqa_store/delete.py | 16 +++++---- ax/storage/sqa_store/json.py | 37 +++++++++------------ ax/storage/sqa_store/load.py | 27 ++++++++------- ax/storage/sqa_store/reduced_state.py | 5 ++- ax/storage/sqa_store/save.py | 42 ++++++++++++----------- ax/storage/sqa_store/sqa_classes.py | 7 ++-- ax/storage/sqa_store/sqa_enum.py | 10 ++---- ax/storage/sqa_store/timestamp.py | 2 +- 19 files changed, 131 insertions(+), 183 deletions(-) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index d220bc56d25..3cc658b51e3 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -474,7 +474,6 @@ def _criterion_from_json( for key, value in object_json.items() } init_args = extract_init_args(args=decoded, class_=criterion_class) - # pyre-ignore[45]: Class passed is always a concrete subclass. return criterion_class(**init_args) diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 2cc7a011eef..accf67e599f 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -454,8 +454,9 @@ def choice_parameter_from_json( # JSON converts dictionary keys to strings. We need to convert them back. if dependents is not None: dependents = { - # pyre-ignore [6]: JSON keys are always strings - string_to_parameter_value(s=key, parameter_type=parameter_type): value + string_to_parameter_value( + s=assert_is_instance(key, str), parameter_type=parameter_type + ): value for key, value in dependents.items() } @@ -499,8 +500,9 @@ def fixed_parameter_from_json( # JSON converts dictionary keys to strings. We need to convert them back. if dependents is not None: dependents = { - # pyre-ignore [6]: JSON keys are always strings - string_to_parameter_value(s=key, parameter_type=parameter_type): value + string_to_parameter_value( + s=assert_is_instance(key, str), parameter_type=parameter_type + ): value for key, value in dependents.items() } diff --git a/ax/storage/json_store/encoder.py b/ax/storage/json_store/encoder.py index 7e688debf84..18a0130178c 100644 --- a/ax/storage/json_store/encoder.py +++ b/ax/storage/json_store/encoder.py @@ -31,13 +31,11 @@ def object_to_json( obj: Any, - # pyre-ignore[24]: Missing parameter annotation, Invalid type parameters encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, - # pyre-ignore[24]: Missing parameter annotation, Invalid type parameters class_encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_CLASS_ENCODER_REGISTRY, ) -> Any: """Convert an Ax object to a JSON-serializable dictionary. diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 522eef89658..0d49f1abe20 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -8,7 +8,7 @@ import warnings from pathlib import Path -from typing import Any +from typing import Any, cast from ax.adapter.transforms.base import Transform from ax.core import Experiment, ObservationFeatures @@ -74,6 +74,7 @@ from botorch.models.transforms.input import ChainedInputTransform, InputTransform from botorch.sampling.base import MCSampler from botorch.utils.types import _DefaultType +from pyre_extensions import assert_is_instance from torch import Tensor @@ -397,10 +398,13 @@ def transform_type_to_dict(transform_type: type[Transform]) -> dict[str, Any]: def generation_step_to_dict(generation_step: GenerationStep) -> dict[str, Any]: - """Converts Ax generation step to a dictionary.""" - # pyre-fixme[6]: Currently, Pyre doesn't recognize that `Generation - # Step.__new__` actually returns a `GenerationNode`. - return generation_node_to_dict(generation_node=generation_step) + """Converts Ax generation step to a dictionary. + + Note: ``GenerationStep.__new__`` actually returns a ``GenerationNode``. + """ + return generation_node_to_dict( + generation_node=cast(GenerationNode, generation_step) + ) def generation_node_to_dict(generation_node: GenerationNode) -> dict[str, Any]: @@ -582,14 +586,14 @@ def botorch_input_transform_to_init_args( if isinstance(input_transform, ChainedInputTransform): return {k: botorch_component_to_dict(v) for k, v in input_transform.items()} else: - try: - # pyre-fixme[29]: `Union[Tensor, Module]` is not a function. - return input_transform.get_init_args() - except AttributeError: + if not hasattr(input_transform, "get_init_args"): raise JSONEncodeError( f"{input_transform.__class__.__name__} does not define `get_init_args` " "method. Please implement it to enable storage." ) + # pyre-fixme[29]: `Union[Tensor, Module]` is not callable; hasattr guards + # this but pyre can't narrow the Union type. + return assert_is_instance(input_transform, InputTransform).get_init_args() def percentile_early_stopping_strategy_to_dict( diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index cd7bc584d51..3a7f4fcec97 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -184,9 +184,7 @@ from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior -# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to -# avoid runtime subscripting errors. -CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = { +CORE_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = { Arm: arm_to_dict, AuxiliaryExperiment: auxiliary_experiment_to_dict, AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict, @@ -269,9 +267,7 @@ # NOTE: Avoid putting a class along with its subclass in `CLASS_ENCODER_REGISTRY`. # The encoder iterates through this dictionary and uses the first superclass that # it finds, which might not be the intended superclass. -# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to -# avoid runtime subscripting errors. -CORE_CLASS_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = { +CORE_CLASS_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = { Acquisition: botorch_modular_to_dict, # Ax MBM component AcquisitionFunction: botorch_modular_to_dict, # BoTorch component InputTransform: botorch_modular_to_dict, # BoTorch input transform component diff --git a/ax/storage/json_store/save.py b/ax/storage/json_store/save.py index 2fc1244bd8d..3519b2a6ef1 100644 --- a/ax/storage/json_store/save.py +++ b/ax/storage/json_store/save.py @@ -21,15 +21,11 @@ def save_experiment( experiment: Experiment, filepath: str, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. class_encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_CLASS_ENCODER_REGISTRY, ) -> None: """Save experiment to file. diff --git a/ax/storage/metric_registry.py b/ax/storage/metric_registry.py index 1a7ef4afa45..bc8f93f0d99 100644 --- a/ax/storage/metric_registry.py +++ b/ax/storage/metric_registry.py @@ -50,17 +50,13 @@ def register_metrics( metric_clss: dict[type[Metric], int | None], - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to - # avoid runtime subscripting errors. ) -> tuple[ dict[type[Metric], int], - dict[type, Callable[[Any], dict[str, Any]]], + dict[type[Any], Callable[[Any], dict[str, Any]]], TDecoderRegistry, ]: """Add custom metric classes to the SQA and JSON registries. diff --git a/ax/storage/registry_bundle.py b/ax/storage/registry_bundle.py index 40ed81500d6..dbde9930738 100644 --- a/ax/storage/registry_bundle.py +++ b/ax/storage/registry_bundle.py @@ -64,12 +64,8 @@ def __init__( self, metric_clss: dict[type[Metric], int | None], runner_clss: dict[type[Runner], int | None], - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]], - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]], + json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]], + json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]], json_decoder_registry: TDecoderRegistry, json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]], ) -> None: @@ -79,18 +75,18 @@ def __init__( runner_clss = { k: int(v) if v is not None else None for k, v in runner_clss.items() } - # pyre-fixme[4]: Attribute must be annotated. + self._metric_registry: dict[type[Metric], int] + self._runner_registry: dict[type[Runner], int] + self._encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] + self._decoder_registry: TDecoderRegistry self._metric_registry, encoder_registry, decoder_registry = register_metrics( metric_clss=metric_clss, encoder_registry=json_encoder_registry, decoder_registry=json_decoder_registry, ) ( - # pyre-fixme[4]: Attribute must be annotated. self._runner_registry, - # pyre-fixme[4]: Attribute must be annotated. self._encoder_registry, - # pyre-fixme[4]: Attribute must be annotated. self._decoder_registry, ) = register_runners( runner_clss=runner_clss, @@ -110,9 +106,7 @@ def runner_registry(self) -> dict[type[Runner], int]: return self._runner_registry @property - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - def encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]: + def encoder_registry(self) -> dict[type[Any], Callable[[Any], dict[str, Any]]]: return self._encoder_registry @property @@ -120,9 +114,9 @@ def decoder_registry(self) -> TDecoderRegistry: return self._decoder_registry @property - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - def class_encoder_registry(self) -> dict[type, Callable[[Any], dict[str, Any]]]: + def class_encoder_registry( + self, + ) -> dict[type[Any], Callable[[Any], dict[str, Any]]]: return self._json_class_encoder_registry @property @@ -177,15 +171,11 @@ def __init__( self, metric_clss: dict[type[Metric], int | None], runner_clss: dict[type[Runner], int | None], - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. json_encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. json_class_encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_CLASS_ENCODER_REGISTRY, json_decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, json_class_decoder_registry: dict[ diff --git a/ax/storage/runner_registry.py b/ax/storage/runner_registry.py index 12803e2de9c..a78fa4388b0 100644 --- a/ax/storage/runner_registry.py +++ b/ax/storage/runner_registry.py @@ -31,23 +31,17 @@ CORE_RUNNER_REGISTRY: dict[type[Runner], int] = {SyntheticRunner: 0} -# pyre-fixme[3]: Return annotation cannot contain `Any`. def register_runner( runner_cls: type[Runner], runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, val: int | None = None, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to - # avoid runtime subscripting errors. ) -> tuple[ dict[type[Runner], int], - dict[type, Callable[[Any], dict[str, Any]]], + dict[type[Any], Callable[[Any], dict[str, Any]]], TDecoderRegistry, ]: """Add a custom runner class to the SQA and JSON registries. @@ -62,22 +56,16 @@ def register_runner( return new_runner_registry, new_encoder_registry, new_decoder_registry -# pyre-fixme[3]: Return annotation cannot contain `Any`. def register_runners( runner_clss: dict[type[Runner], int | None], runner_registry: dict[type[Runner], int] = CORE_RUNNER_REGISTRY, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. encoder_registry: dict[ - type, Callable[[Any], dict[str, Any]] + type[Any], Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to - # avoid runtime subscripting errors. ) -> tuple[ dict[type[Runner], int], - dict[type, Callable[[Any], dict[str, Any]]], + dict[type[Any], Callable[[Any], dict[str, Any]]], TDecoderRegistry, ]: """Add custom runner classes to the SQA and JSON registries. diff --git a/ax/storage/sqa_store/db.py b/ax/storage/sqa_store/db.py index 9da5d98e4e2..a22122648f7 100644 --- a/ax/storage/sqa_store/db.py +++ b/ax/storage/sqa_store/db.py @@ -33,7 +33,7 @@ LONGTEXT_BYTES: int = 2**32 - 1 # global database variables -SESSION_FACTORY: Session | None = None +SESSION_FACTORY: scoped_session | None = None # set this to false to prevent SQLAlchemy for automatically expiring objects # on commit, which essentially makes them unusable outside of a session @@ -235,7 +235,6 @@ def get_session() -> Session: if SESSION_FACTORY is None: init_engine_and_session_factory() assert SESSION_FACTORY is not None - # pyre-fixme[29]: `Session` is not a function. return SESSION_FACTORY() diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index a4a68989ff4..ec138e36ee0 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -56,6 +56,7 @@ from ax.core.search_space import SearchSpace from ax.core.trial import Trial from ax.core.trial_status import TrialStatus +from ax.core.types import TModelPredict, TModelPredictArm from ax.exceptions.storage import JSONDecodeError, SQADecodeError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.json_store.decoder import _DEPRECATED_GENERATOR_KWARGS, object_from_json @@ -135,7 +136,7 @@ def get_enum_name( return None try: - return enum(value).name # pyre-ignore T29651755 + return cast(type[Enum], enum)(value).name except ValueError: raise SQADecodeError(f"Value {value} is invalid for enum {enum}.") @@ -304,7 +305,7 @@ def _init_mt_experiment_from_sqa( ) default_trial_type = none_throws(experiment_sqa.default_trial_type) - trial_type_to_runner = { + trial_type_to_runner: dict[str, Runner | None] = { none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner) for sqa_runner in experiment_sqa.runners } @@ -330,9 +331,8 @@ def _init_mt_experiment_from_sqa( status_quo=status_quo, properties=properties, ) - # pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner - # has type Dict[str, Optional[Runner]] but is used as type - # Uniont[Dict[str, Optional[Runner]], Dict[str, None]] + # pyre-fixme[8]: `_trial_type_to_runner` expects `Dict[Optional[str], + # Optional[Runner]]` but the dict built here uses `str` keys. experiment._trial_type_to_runner = trial_type_to_runner sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics} for tracking_metric in tracking_metrics: @@ -730,8 +730,8 @@ def generator_run_from_sqa( opt_config = None search_space = None - best_arm_predictions = None - model_predictions = None + best_arm_predictions: tuple[Arm, TModelPredictArm | None] | None = None + model_predictions: TModelPredict | None = None if ( generator_run_sqa.best_arm_parameters is not None and generator_run_sqa.best_arm_predictions is not None @@ -740,15 +740,14 @@ def generator_run_from_sqa( name=generator_run_sqa.best_arm_name, parameters=none_throws(generator_run_sqa.best_arm_parameters), ) + raw_predictions = none_throws(generator_run_sqa.best_arm_predictions) best_arm_predictions = ( best_arm, - tuple(none_throws(generator_run_sqa.best_arm_predictions)), + cast(TModelPredictArm, tuple(raw_predictions)), ) - model_predictions = ( - tuple(none_throws(generator_run_sqa.model_predictions)) - if generator_run_sqa.model_predictions is not None - else None - ) + if generator_run_sqa.model_predictions is not None: + raw_model_predictions = none_throws(generator_run_sqa.model_predictions) + model_predictions = cast(TModelPredict, tuple(raw_model_predictions)) generator_run = GeneratorRun( arms=arms, @@ -765,11 +764,7 @@ def generator_run_from_sqa( if generator_run_sqa.gen_time is None else float(generator_run_sqa.gen_time) ), - best_arm_predictions=best_arm_predictions, # pyre-ignore[6] - # pyre-fixme[6]: Expected `Optional[Tuple[typing.Dict[str, List[float]], - # typing.Dict[str, typing.Dict[str, List[float]]]]]` for 8th param but got - # `Optional[typing.Tuple[Union[typing.Dict[str, List[float]], - # typing.Dict[str, typing.Dict[str, List[float]]]], ...]]`. + best_arm_predictions=best_arm_predictions, model_predictions=model_predictions, generator_key=generator_run_sqa.model_key, generator_kwargs=( @@ -933,7 +928,8 @@ def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner: decoder_registry=self.config.json_decoder_registry, class_decoder_registry=self.config.json_class_decoder_registry, ) - # pyre-ignore[45]: Cannot instantiate abstract class `Runner`. + # pyre-fixme[45]: `runner_class` is always a concrete subclass at runtime, + # but pyre sees `Runner` (abstract) from the reverse_runner_registry type. runner = runner_class(**args) runner.db_id = runner_sqa.id return runner @@ -1012,21 +1008,11 @@ def trial_from_sqa( trial._time_staged = trial_sqa.time_staged trial._time_run_started = trial_sqa.time_run_started trial._status_reason = trial_sqa.abandoned_reason or trial_sqa.failed_reason - # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as - # `Optional[Dict[str, Any]]`. - # pyre-fixme[8]: Attribute has type `Dict[str, typing.Any]`; used as - # `Optional[typing.Dict[Variable[_KT], Variable[_VT]]]`. trial._run_metadata = ( - dict(trial_sqa.run_metadata) if trial_sqa.run_metadata is not None else None + dict(trial_sqa.run_metadata) if trial_sqa.run_metadata is not None else {} ) - # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as - # `Optional[Dict[str, Any]]`. - # pyre-fixme[8]: Attribute has type `Dict[str, typing.Any]`; used as - # `Optional[typing.Dict[Variable[_KT], Variable[_VT]]]`. trial._stop_metadata = ( - dict(trial_sqa.stop_metadata) - if trial_sqa.stop_metadata is not None - else None + dict(trial_sqa.stop_metadata) if trial_sqa.stop_metadata is not None else {} ) trial._num_arms_created = trial_sqa.num_arms_created trial._properties = dict(trial_sqa.properties or {}) diff --git a/ax/storage/sqa_store/delete.py b/ax/storage/sqa_store/delete.py index 00e08977aa2..ebf67ceea25 100644 --- a/ax/storage/sqa_store/delete.py +++ b/ax/storage/sqa_store/delete.py @@ -6,12 +6,13 @@ # pyre-strict from logging import Logger +from typing import cast from ax.core.experiment import Experiment from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.sqa_store.db import session_scope from ax.storage.sqa_store.decoder import Decoder -from ax.storage.sqa_store.sqa_classes import SQAExperiment +from ax.storage.sqa_store.sqa_classes import SQAExperiment, SQAGenerationStrategy from ax.storage.sqa_store.sqa_config import SQAConfig from ax.utils.common.logger import get_logger @@ -61,12 +62,13 @@ def delete_generation_strategy( exp_sqa_class = decoder.config.class_to_sqa_class[Experiment] gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy] # get the generation strategy's db_id + gs_sqa_class_typed = cast(type[SQAGenerationStrategy], gs_sqa_class) + exp_sqa_class_typed = cast(type[SQAExperiment], exp_sqa_class) with session_scope() as session: sqa_gs_ids = ( - session.query(gs_sqa_class.id) # pyre-ignore[16] - .join(exp_sqa_class.generation_strategy) # pyre-ignore[16] - # pyre-fixme[16]: `SQABase` has no attribute `name`. - .filter(exp_sqa_class.name == exp_name) + session.query(gs_sqa_class_typed.id) + .join(exp_sqa_class_typed.generation_strategy) + .filter(exp_sqa_class_typed.name == exp_name) .all() ) @@ -84,8 +86,8 @@ def delete_generation_strategy( # delete generation strategy with session_scope() as session: gs_list = ( - session.query(gs_sqa_class) - .filter(gs_sqa_class.id.in_([id[0] for id in sqa_gs_ids])) + session.query(gs_sqa_class_typed) + .filter(gs_sqa_class_typed.id.in_([id[0] for id in sqa_gs_ids])) .all() ) for gs in gs_list: diff --git a/ax/storage/sqa_store/json.py b/ax/storage/sqa_store/json.py index 9d09eaca96b..5cbb682050c 100644 --- a/ax/storage/sqa_store/json.py +++ b/ax/storage/sqa_store/json.py @@ -30,27 +30,24 @@ class JSONEncodedObject(TypeDecorator): def __init__( self, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - object_pairs_hook: Any = None, + object_pairs_hook: type[Any] | None = None, *args: list[Any], **kwargs: dict[Any, Any], ) -> None: - # pyre-fixme[4]: Attribute annotation cannot be `Any`. - self.object_pairs_hook: Any = object_pairs_hook + self.object_pairs_hook: type[Any] | None = object_pairs_hook super().__init__(*args, **kwargs) - # pyre-fixme[2]: Parameter annotation cannot be `Any`. def process_bind_param(self, value: Any, dialect: Any) -> str | None: if value is not None: return json.dumps(value) else: return None - # pyre-fixme[3]: Return annotation cannot be `Any`. - # pyre-fixme[2]: Parameter annotation cannot be `Any`. def process_result_value(self, value: Any, dialect: Any) -> Any: if value is not None: try: # TODO T61331534: revert this; just a hotfix for AutoML + # pyre-fixme[6]: `object_pairs_hook` expects a callable but + # `type[Any] | None` is stored; compatible at runtime. return json.loads(value, object_pairs_hook=self.object_pairs_hook) except JSONDecodeError: return None @@ -65,8 +62,8 @@ class JSONEncodedText(JSONEncodedObject): """ - # pyre-fixme[15]: `impl` overrides attribute defined in `JSONEncodedObject` - # inconsistently. + # pyre-fixme[15]: `impl` overrides attribute in `JSONEncodedObject` with + # incompatible type; SQLAlchemy allows broader `impl` types at runtime. impl = Text @@ -78,29 +75,25 @@ class JSONEncodedMediumText(JSONEncodedObject): """ - # pyre-fixme[15]: `impl` overrides attribute defined in `JSONEncodedObject` - # inconsistently. + # pyre-fixme[15]: `impl` overrides attribute in `JSONEncodedObject` with + # incompatible type; SQLAlchemy allows broader `impl` types at runtime. impl = Text(MEDIUMTEXT_BYTES) class JSONEncodedLongText(JSONEncodedObject): - """Class for JSON-encoding objects in SQLAlchemy, backed by MEDIUMTEXT + """Class for JSON-encoding objects in SQLAlchemy, backed by LONGTEXT (MySQL). See description in JSONEncodedObject. """ - # pyre-fixme[15]: `impl` overrides attribute defined in `JSONEncodedObject` - # inconsistently. + # pyre-fixme[15]: `impl` overrides attribute in `JSONEncodedObject` with + # incompatible type; SQLAlchemy allows broader `impl` types at runtime. impl = Text(LONGTEXT_BYTES) -# pyre-fixme[5]: Global expression must be annotated. -JSONEncodedList = MutableList.as_mutable(JSONEncodedObject) -# pyre-fixme[5]: Global expression must be annotated. -JSONEncodedDict = MutableDict.as_mutable(JSONEncodedObject) -# pyre-fixme[5]: Global expression must be annotated. -JSONEncodedTextDict = MutableDict.as_mutable(JSONEncodedText) -# pyre-fixme[5]: Global expression must be annotated. -JSONEncodedLongTextDict = MutableDict.as_mutable(JSONEncodedLongText) +JSONEncodedList: TypeDecorator = MutableList.as_mutable(JSONEncodedObject) +JSONEncodedDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedObject) +JSONEncodedTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedText) +JSONEncodedLongTextDict: TypeDecorator = MutableDict.as_mutable(JSONEncodedLongText) diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 2c6d0f7d146..d271c5cfd7d 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -419,10 +419,10 @@ def _get_experiment_id(experiment_name: str, config: SQAConfig) -> int | None: """Get DB ID of the experiment by the given name if its in DB, return None otherwise. """ - exp_sqa_class = config.class_to_sqa_class[Experiment] + exp_sqa_class = cast(type[SQAExperiment], config.class_to_sqa_class[Experiment]) with session_scope() as session: sqa_experiment_id = ( - session.query(exp_sqa_class.id) # pyre-ignore + session.query(exp_sqa_class.id) .filter_by(name=experiment_name) .one_or_none() ) @@ -522,11 +522,7 @@ def _load_generation_strategy_by_id( ( _get_experiment_immutable_opt_config_and_search_space( experiment_name=experiment.name, - # pyre-ignore Incompatible parameter type [6]: Expected - # `Type[SQAExperiment]` for 2nd parameter `exp_sqa_class` - # to call `_get_experiment_immutable_opt_config_and_search_space` - # but got `Type[ax.storage.sqa_store.db.SQABase]`. - exp_sqa_class=exp_sqa_class, + exp_sqa_class=cast(type[SQAExperiment], exp_sqa_class), ) ) if experiment is not None @@ -553,13 +549,14 @@ def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> int | """ exp_sqa_class = decoder.config.class_to_sqa_class[Experiment] gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy] + gs_sqa_class_typed = cast(type[SQAGenerationStrategy], gs_sqa_class) + exp_sqa_class_typed = cast(type[SQAExperiment], exp_sqa_class) with session_scope() as session: sqa_gs_ids = ( - session.query(gs_sqa_class.id) # pyre-ignore[16] - .join(exp_sqa_class.generation_strategy) # pyre-ignore[16] - # pyre-fixme[16]: `SQABase` has no attribute `name`. - .filter(exp_sqa_class.name == experiment_name) - .order_by(gs_sqa_class.id.desc()) + session.query(gs_sqa_class_typed.id) + .join(exp_sqa_class_typed.generation_strategy) + .filter(exp_sqa_class_typed.name == experiment_name) + .order_by(gs_sqa_class_typed.id.desc()) .all() ) @@ -649,10 +646,12 @@ def get_generator_runs_by_id( immutable_search_space_and_opt_config: bool = False, ) -> list[GeneratorRun]: """Bulk fetches generator runs by id.""" - generator_run_sqa_class = decoder.config.class_to_sqa_class[GeneratorRun] + generator_run_sqa_class = cast( + type[SQAGeneratorRun], decoder.config.class_to_sqa_class[GeneratorRun] + ) with session_scope() as session: query = session.query(generator_run_sqa_class).filter( - generator_run_sqa_class.id.in_(generator_run_ids) # pyre-ignore[16] + generator_run_sqa_class.id.in_(generator_run_ids) ) sqa_grs = query.all() return [ diff --git a/ax/storage/sqa_store/reduced_state.py b/ax/storage/sqa_store/reduced_state.py index 0117d2a52b3..39cdf76780e 100644 --- a/ax/storage/sqa_store/reduced_state.py +++ b/ax/storage/sqa_store/reduced_state.py @@ -12,7 +12,10 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute -GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [ # pyre-ignore[9] +# pyre-fixme[9]: `GR_LARGE_MODEL_ATTRS` is declared as `List[InstrumentedAttribute]` +# but SQLAlchemy class attributes are typed as `Column` in stubs; they are +# `InstrumentedAttribute` instances at runtime. +GR_LARGE_MODEL_ATTRS: list[InstrumentedAttribute] = [ SQAGeneratorRun.model_kwargs, SQAGeneratorRun.bridge_kwargs, SQAGeneratorRun.model_state_after_gen, diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index 437855ad1ee..16e60add3d9 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -7,7 +7,7 @@ # pyre-strict import os -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence from logging import Logger from typing import Any, cast, Type @@ -32,6 +32,7 @@ from ax.storage.sqa_store.encoder import Encoder from ax.storage.sqa_store.sqa_classes import ( SQAData, + SQAExperiment, SQAGeneratorRun, SQAMetric, SQARunner, @@ -97,16 +98,20 @@ def _save_experiment( existing SQLAlchemy object, and then letting SQLAlchemy handle the actual DB updates. """ - exp_sqa_class = encoder.config.class_to_sqa_class[Experiment] + exp_sqa_class = cast( + Type[SQAExperiment], encoder.config.class_to_sqa_class[Experiment] + ) with session_scope() as session: - existing_sqa_experiment_id = ( - # pyre-ignore Undefined attribute [16]: `SQABase` has no attribute `id` + existing_sqa_experiment_id_row = ( session.query(exp_sqa_class.id) .filter_by(name=experiment.name) .one_or_none() ) - if existing_sqa_experiment_id: - existing_sqa_experiment_id = existing_sqa_experiment_id[0] + existing_sqa_experiment_id: int | None = ( + existing_sqa_experiment_id_row[0] + if existing_sqa_experiment_id_row is not None + else None + ) encoder.validate_experiment_metadata( experiment, @@ -422,9 +427,10 @@ def _update_generation_strategy( """Update generation strategy's current step and attach generator runs.""" gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy] - gs_id = generation_strategy.db_id - if gs_id is None: - raise ValueError("GenerationStrategy must be saved before being updated.") + gs_id: int = none_throws( + generation_strategy.db_id, + "GenerationStrategy must be saved before being updated.", + ) experiment_id = generation_strategy.experiment.db_id if experiment_id is None: @@ -444,13 +450,12 @@ def _update_generation_strategy( } ) - # pyre-fixme[53]: Captured variable `gs_id` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - def add_generation_strategy_id(sqa: SQAGeneratorRun): + def add_generation_strategy_id(sqa: SQAGeneratorRun) -> None: sqa.generation_strategy_id = gs_id - # pyre-fixme[3]: Return type must be annotated. - def generator_run_to_sqa_encoder(gr: GeneratorRun, weight: float | None = None): + def generator_run_to_sqa_encoder( + gr: GeneratorRun, weight: float | None = None + ) -> SQAGeneratorRun: return encoder.generator_run_to_sqa( gr, weight=weight, @@ -483,8 +488,7 @@ def update_runner_on_experiment( with session_scope() as session: session.query(runner_sqa_class).filter_by(experiment_id=exp_id).delete() - # pyre-fixme[3]: Return type must be annotated. - def add_experiment_id(sqa: SQARunner): + def add_experiment_id(sqa: SQARunner) -> None: sqa.experiment_id = exp_id _merge_into_session( @@ -682,9 +686,9 @@ def _bulk_merge_into_session( sqas.append(sqa) # https://stackoverflow.com/a/312464 - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def split_into_batches(lst, n): + def split_into_batches( + lst: list[SQABase], n: int + ) -> Generator[list[SQABase], None, None]: for i in range(0, len(lst), n): yield lst[i : i + n] diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 1176634b9e2..ee656e8483a 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -384,11 +384,8 @@ class SQAExperiment(Base): ) default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True) - # pyre-fixme[8]: Incompatible attribute type [8]: Attribute - # `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has - # type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]` - auxiliary_experiments_by_purpose: dict[str, list[dict[str, Any]]] | None = Column( - JSONEncodedTextDict, nullable=True, default={} + auxiliary_experiments_by_purpose: Column[dict[str, list[dict[str, Any]]] | None] = ( + Column(JSONEncodedTextDict, nullable=True, default={}) ) # relationships diff --git a/ax/storage/sqa_store/sqa_enum.py b/ax/storage/sqa_store/sqa_enum.py index ca7d202cc9c..52286e5ec68 100644 --- a/ax/storage/sqa_store/sqa_enum.py +++ b/ax/storage/sqa_store/sqa_enum.py @@ -18,10 +18,8 @@ class BaseNullableEnum(types.TypeDecorator): def __init__(self, enum: Any, *arg: list[Any], **kw: dict[Any, Any]) -> None: types.TypeDecorator.__init__(self, *arg, **kw) - # pyre-fixme[4]: Attribute must be annotated. - self._member_map = enum._member_map_ - # pyre-fixme[4]: Attribute must be annotated. - self._value2member_map = enum._value2member_map_ + self._member_map: dict[str, Any] = enum._member_map_ + self._value2member_map: dict[Any, Any] = enum._value2member_map_ def process_bind_param(self, value: Any, dialect: Any) -> Any: if value is None: @@ -52,9 +50,7 @@ def process_result_value(self, value: Any, dialect: Any) -> Any: class IntEnum(BaseNullableEnum): - # pyre-fixme[8]: Attribute has type `SmallInteger`; used as - # `Type[sqlalchemy.sql.sqltypes.SmallInteger]`. - impl: types.SmallInteger = types.SmallInteger + impl = types.SmallInteger class StringEnum(BaseNullableEnum): diff --git a/ax/storage/sqa_store/timestamp.py b/ax/storage/sqa_store/timestamp.py index d3036002dad..481f10a9cca 100644 --- a/ax/storage/sqa_store/timestamp.py +++ b/ax/storage/sqa_store/timestamp.py @@ -17,7 +17,7 @@ class IntTimestamp(TypeDecorator): cache_ok = True # pyre-fixme[15]: `process_bind_param` overrides method defined in - # `TypeDecorator` inconsistently. + # `TypeDecorator` inconsistently; returns `int | None` vs `str | None`. def process_bind_param( self, value: datetime.datetime | None, dialect: Dialect ) -> int | None: From 463604ab564a9f8579c28a7b0f0c38c13e0e0e67 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 6 Mar 2026 06:50:16 -0800 Subject: [PATCH 2/3] Remove pyre-fixme/pyre-ignore from ax/core, ax/adapter, ax/generators test files Summary: Remove pyre-fixme and pyre-ignore type suppression comments from test files in ax/core/tests, ax/adapter/tests, ax/adapter/transforms/tests, and source file ax/adapter/transforms/one_hot.py. Uses proper type narrowing via none_throws, assert_is_instance, cast, and explicit type annotations instead of suppression comments. Key changes: - Replace `# pyre-ignore[16]` on `Parameter` attribute access with `assert_is_instance(..., RangeParameter)` / `ChoiceParameter` / `FixedParameter` - Replace `# pyre-fixme[16]` on Optional access with `none_throws(...)` - Add explicit type annotations (`TParameterization`, `TConfig`, `list[float]`, `dict[str, float | int]`) to fix type inference issues - Replace `**attrs` dict unpacking with explicit kwargs to eliminate union-type pyre errors in test_observation.py - Fix `all()` generator expression scoping bug in test_batch_trial.py (missing parentheses caused pyre-fixme[6]) - Remove unnecessary `return` statements inside `assertRaises` blocks - Add missing return type and parameter annotations on mock-decorated test methods - Refactor BoTorchGenerator construction in test_cross_validation.py to avoid pyre-ignore on `adapter.generator.surrogate` access Differential Revision: D95273495 --- ax/adapter/tests/test_base_adapter.py | 36 +++- ax/adapter/tests/test_cross_validation.py | 35 ++-- ax/adapter/tests/test_prediction_utils.py | 9 +- ax/adapter/tests/test_random_adapter.py | 11 +- ax/adapter/tests/test_torch_adapter.py | 2 +- ax/adapter/tests/test_torch_moo_adapter.py | 3 +- ax/adapter/tests/test_utils.py | 3 - .../transforms/tests/test_base_transform.py | 5 +- .../transforms/tests/test_cast_transform.py | 11 +- .../tests/test_choice_encode_transform.py | 5 +- .../test_int_range_to_choice_transform.py | 2 +- .../transforms/tests/test_logit_transform.py | 16 +- .../tests/test_metrics_as_task_transform.py | 7 +- .../tests/test_one_hot_transform.py | 17 +- .../tests/test_task_encode_transform.py | 6 +- ax/core/tests/test_batch_trial.py | 44 ++-- ax/core/tests/test_generator_run.py | 23 +-- ax/core/tests/test_objective.py | 6 +- ax/core/tests/test_observation.py | 190 +++++++++--------- ax/core/tests/test_outcome_constraint.py | 3 +- ax/core/tests/test_parameter_constraint.py | 6 +- ax/core/tests/test_runner.py | 3 +- ax/core/tests/test_trial.py | 13 +- ax/core/tests/test_utils.py | 7 +- 24 files changed, 225 insertions(+), 238 deletions(-) diff --git a/ax/adapter/tests/test_base_adapter.py b/ax/adapter/tests/test_base_adapter.py index 97716a13993..6e4e5eed757 100644 --- a/ax/adapter/tests/test_base_adapter.py +++ b/ax/adapter/tests/test_base_adapter.py @@ -76,7 +76,7 @@ from botorch.exceptions.warnings import InputDataWarning from botorch.models.utils.assorted import validate_input_scaling from pandas.testing import assert_frame_equal -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws ADAPTER__GEN_PATH: str = "ax.adapter.base.Adapter._gen" @@ -908,8 +908,14 @@ def test_set_model_space(self) -> None: .index.get_level_values("arm_name") ) self.assertEqual(set(ood_arms), {"status_quo", "custom"}) - self.assertEqual(m.model_space.parameters["x1"].lower, -5.0) # pyre-ignore[16] - self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16] + self.assertEqual( + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, + -5.0, + ) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 15.0, + ) self.assertEqual(len(m.model_space.parameter_constraints), 1) # With expand model space, custom is not OOD, and model space is expanded @@ -925,8 +931,14 @@ def test_set_model_space(self) -> None: .index.get_level_values("arm_name") ) self.assertEqual(set(ood_arms), {"status_quo"}) - self.assertEqual(m.model_space.parameters["x1"].lower, -20.0) - self.assertEqual(m.model_space.parameters["x2"].upper, 18.0) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, + -20.0, + ) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 18.0, + ) self.assertEqual(m.model_space.parameter_constraints, []) # With fill values, SQ is also in design, and x2 is further expanded @@ -941,7 +953,10 @@ def test_set_model_space(self) -> None: transform_configs={"FillMissingParameters": {"fill_values": sq_vals}}, ) self.assertEqual(sum(m.training_in_design), 7) - self.assertEqual(m.model_space.parameters["x2"].upper, 20) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 20, + ) self.assertEqual(m.model_space.parameter_constraints, []) # Using parameter backfill values @@ -955,7 +970,10 @@ def test_set_model_space(self) -> None: search_space=ss, ) self.assertEqual(sum(m.training_in_design), 7) - self.assertEqual(m.model_space.parameters["x2"].upper, 20) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 20, + ) self.assertEqual(m.model_space.parameter_constraints, []) # Check log scale expansion with OOD trial having parameter value == 0 @@ -992,12 +1010,12 @@ def test_set_model_space(self) -> None: # Assert that the expanded model space did not include 0.0 self.assertEqual( - m.model_space.parameters["x1"].lower, + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, 0.0001, ) # x2 model space should still be expanded self.assertEqual( - m.model_space.parameters["x2"].upper, + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, 2.0, ) diff --git a/ax/adapter/tests/test_cross_validation.py b/ax/adapter/tests/test_cross_validation.py index 55908746e1d..5001bdbf608 100644 --- a/ax/adapter/tests/test_cross_validation.py +++ b/ax/adapter/tests/test_cross_validation.py @@ -9,6 +9,7 @@ import warnings from collections.abc import Iterable from itertools import product +from typing import cast from unittest import mock import numpy as np @@ -70,6 +71,7 @@ from gpytorch.distributions import MultivariateNormal from linear_operator.operators import DiagLinearOperator from pandas import DataFrame +from pyre_extensions import assert_is_instance # Number of in-design points created by _create_adapter_with_out_of_design_points() _OOD_ADAPTER_IN_DESIGN_COUNT = 3 @@ -78,9 +80,8 @@ class CrossValidationTest(TestCase): def setUp(self) -> None: super().setUp() - # pyre-ignore [9] Pyre is too picky with union types. parameterizations: list[TParameterization] = [ - {"x": x} for x in [2.0, 2.0, 3.0, 4.0] + cast(TParameterization, {"x": x}) for x in [2.0, 2.0, 3.0, 4.0] ] means = [[2.0, 4.0], [3.0, 5.0], [7.0, 8.0], [9.0, 10.0]] sems = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]] @@ -894,29 +895,27 @@ def test_efficient_loo_cv_with_fully_bayesian_model(self) -> None: experiment = get_branin_experiment(with_batch=True, with_completed_batch=True) # Create adapter with SaasFullyBayesianSingleTaskGP + generator = BoTorchGenerator( + surrogate=Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ + ModelConfig( + botorch_model_class=SaasFullyBayesianSingleTaskGP, + ) + ], + ), + ) + ) adapter = TorchAdapter( experiment=experiment, - generator=BoTorchGenerator( - surrogate=Surrogate( - surrogate_spec=SurrogateSpec( - model_configs=[ - ModelConfig( - botorch_model_class=SaasFullyBayesianSingleTaskGP, - ) - ], - ), - ) - ), + generator=generator, transforms=[UnitX], ) # We need to mock the MCMC fitting to avoid running actual NUTS sampling # which is very slow. Instead, we'll inject mock MCMC samples. - surrogate = adapter.generator.surrogate # pyre-ignore[16] - model = surrogate.model - - # Verify the model is a SaasFullyBayesianSingleTaskGP - self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP) + surrogate = generator.surrogate + model = assert_is_instance(surrogate.model, SaasFullyBayesianSingleTaskGP) # Get training data shape info train_X = model.train_inputs[0] diff --git a/ax/adapter/tests/test_prediction_utils.py b/ax/adapter/tests/test_prediction_utils.py index f72c20fe070..080cfe47015 100644 --- a/ax/adapter/tests/test_prediction_utils.py +++ b/ax/adapter/tests/test_prediction_utils.py @@ -83,14 +83,11 @@ def test_predict_by_features(self) -> None: @mock.patch("ax.adapter.random.RandomAdapter.predict") @mock.patch("ax.adapter.random.RandomAdapter") - # pyre-fixme[3]: Return type must be annotated. def test_predict_by_features_with_non_predicting_model( self, - # pyre-fixme[2]: Parameter must be annotated. - adapter_mock, - # pyre-fixme[2]: Parameter must be annotated. - predict_mock, - ): + adapter_mock: mock.MagicMock, + predict_mock: mock.MagicMock, + ) -> None: ax_client = _set_up_client_for_get_model_predictions_no_next_trial() _attach_completed_trials(ax_client) diff --git a/ax/adapter/tests/test_random_adapter.py b/ax/adapter/tests/test_random_adapter.py index 1cb4b4522a6..689c7ef121b 100644 --- a/ax/adapter/tests/test_random_adapter.py +++ b/ax/adapter/tests/test_random_adapter.py @@ -23,6 +23,7 @@ from ax.exceptions.core import SearchSpaceExhausted from ax.generators.random.base import RandomGenerator from ax.generators.random.sobol import SobolGenerator +from ax.generators.types import TConfig from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_data, @@ -45,7 +46,7 @@ def setUp(self) -> None: ] self.search_space = SearchSpace(self.parameters, parameter_constraints) self.experiment = Experiment(search_space=self.search_space) - self.model_gen_options = {"option": "yes"} + self.model_gen_options: TConfig = {"option": "yes"} def test_fit(self) -> None: adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator()) @@ -79,10 +80,6 @@ def test_gen_w_constraints(self) -> None: pending_observations={}, fixed_features=ObservationFeatures({"z": 3.0}), optimization_config=None, - # pyre-fixme[6]: For 6th param expected `Optional[Dict[str, - # Union[None, Dict[str, typing.Any], OptimizationConfig, - # AcquisitionFunction, float, int, str]]]` but got `Dict[str, - # str]`. model_gen_options=self.model_gen_options, ) gen_args = mock_gen.mock_calls[0][2] @@ -129,10 +126,6 @@ def test_gen_simple(self) -> None: pending_observations={}, fixed_features=ObservationFeatures({}), optimization_config=None, - # pyre-fixme[6]: For 6th param expected `Optional[Dict[str, - # Union[None, Dict[str, typing.Any], OptimizationConfig, - # AcquisitionFunction, float, int, str]]]` but got `Dict[str, - # str]`. model_gen_options=self.model_gen_options, ) gen_args = mock_gen.mock_calls[0][2] diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 8c5288c865c..56d52e30271 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -632,7 +632,7 @@ def test_convert_experiment_data(self) -> None: ordinal_features=[2], discrete_choices={2: list(range(0, 11))}, task_features=[2] if use_task else [], - target_values={2: 0} if use_task else {}, # pyre-ignore + target_values={2: 0.0} if use_task else {}, ) converted_datasets, ordered_outcomes, _ = adapter._convert_experiment_data( experiment_data=experiment_data, diff --git a/ax/adapter/tests/test_torch_moo_adapter.py b/ax/adapter/tests/test_torch_moo_adapter.py index 6e23fb669b4..b01c3dbcd2e 100644 --- a/ax/adapter/tests/test_torch_moo_adapter.py +++ b/ax/adapter/tests/test_torch_moo_adapter.py @@ -322,8 +322,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None: ) for trial in exp.trials.values(): trial.mark_running(no_runner_required=True).mark_completed() - # pyre-fixme[16]: Optional type has no attribute `metrics`. - metrics_dict = exp.optimization_config.metrics + metrics_dict = none_throws(exp.optimization_config).metrics # Objective thresholds and synthetic observations chosen to have closed-form # hypervolumes to test. objective_thresholds = [ diff --git a/ax/adapter/tests/test_utils.py b/ax/adapter/tests/test_utils.py index a4051f12e17..30853afd2d5 100644 --- a/ax/adapter/tests/test_utils.py +++ b/ax/adapter/tests/test_utils.py @@ -85,7 +85,6 @@ def test_extract_outcome_constraints(self) -> None: OutcomeConstraint(metric=Metric("m1"), op=ComparisonOp.LEQ, bound=0) ] res = extract_outcome_constraints(outcome_constraints, outcomes) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. self.assertEqual(res[0].shape, (1, 3)) self.assertListEqual(list(res[0][0]), [1, 0, 0]) self.assertEqual(res[1][0][0], 0) @@ -137,10 +136,8 @@ def test_extract_objective_thresholds(self) -> None: outcomes=outcomes, ) expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3])) self.assertTrue(np.isnan(obj_t[-1])) - # pyre-fixme[16]: Optional type has no attribute `shape`. self.assertEqual(obj_t.shape[0], 4) # Returns NaN for objectives without a threshold. diff --git a/ax/adapter/transforms/tests/test_base_transform.py b/ax/adapter/transforms/tests/test_base_transform.py index e89cec7080f..c2e932daeef 100644 --- a/ax/adapter/transforms/tests/test_base_transform.py +++ b/ax/adapter/transforms/tests/test_base_transform.py @@ -18,6 +18,7 @@ from ax.core.objective import Objective from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig +from ax.core.types import TParameterization from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment @@ -66,10 +67,10 @@ def test_TransformObservations(self) -> None: means = np.array([3.0, 4.0]) metric_signatures = ["a", "b"] covariance = np.array([[1.0, 2.0], [3.0, 4.0]]) - parameters = {"x": 1.0, "y": "cat"} + parameters: TParameterization = {"x": 1.0, "y": "cat"} arm_name = "armmy" observation = Observation( - features=ObservationFeatures(parameters=parameters), # pyre-ignore + features=ObservationFeatures(parameters=parameters), data=ObservationData( metric_signatures=metric_signatures, means=means, covariance=covariance ), diff --git a/ax/adapter/transforms/tests/test_cast_transform.py b/ax/adapter/transforms/tests/test_cast_transform.py index 0dc6940f10c..d41de7085ee 100644 --- a/ax/adapter/transforms/tests/test_cast_transform.py +++ b/ax/adapter/transforms/tests/test_cast_transform.py @@ -35,6 +35,7 @@ ) from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyre_extensions import none_throws class CastTransformTest(TestCase): @@ -179,8 +180,7 @@ def test_transform_observation_features_HSS(self) -> None: self.assertIn(p_name, obsf.parameters) # Check that full parameterization is recorded in metadata self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -197,7 +197,7 @@ def test_transform_observation_features_HSS(self) -> None: self.assertIn(p_name, obsf.parameters) # Check that full parameterization is recorded in metadata self.assertEqual( - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -245,8 +245,7 @@ def test_untransform_observation_features_HSS(self) -> None: }, ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -264,7 +263,7 @@ def test_untransform_observation_features_HSS(self) -> None: }, ) self.assertEqual( - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss_2.parameters, ) diff --git a/ax/adapter/transforms/tests/test_choice_encode_transform.py b/ax/adapter/transforms/tests/test_choice_encode_transform.py index 834d6714f80..5bf8ed02208 100644 --- a/ax/adapter/transforms/tests/test_choice_encode_transform.py +++ b/ax/adapter/transforms/tests/test_choice_encode_transform.py @@ -247,8 +247,9 @@ def test_hss_dependents_are_preserved(self) -> None: # x0 should be untouched because it's a fixed parameter. self.assertIsInstance(hss.parameters["x0"], FixedParameter) self.assertEqual(hss.parameters["x0"].parameter_type, ParameterType.BOOL) - # pyre-ignore[16] # Pyre doesn't understand fixed parameters have `.value` - self.assertEqual(hss.parameters["x0"].value, True) + self.assertEqual( + assert_is_instance(hss.parameters["x0"], FixedParameter).value, True + ) self.assertEqual(hss.parameters["x0"].dependents, {True: ["x1", "x2"]}) self.assertFalse(hss.parameters["x1"].is_hierarchical) diff --git a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py index fee4d0bab9d..34d9ae1944c 100644 --- a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py +++ b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py @@ -72,7 +72,7 @@ def test_num_choices(self) -> None: "e", lower=3, upper=5, parameter_type=ParameterType.INT ), } - search_space = SearchSpace(parameters=parameters.values()) # pyre-ignore[6] + search_space = SearchSpace(parameters=list(parameters.values())) # Don't specify max_choices (should be set to inf) t = IntRangeToChoice(search_space=search_space) diff --git a/ax/adapter/transforms/tests/test_logit_transform.py b/ax/adapter/transforms/tests/test_logit_transform.py index 38fb28afc8f..1c567c79a76 100644 --- a/ax/adapter/transforms/tests/test_logit_transform.py +++ b/ax/adapter/transforms/tests/test_logit_transform.py @@ -18,6 +18,7 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment_with_observations from pandas.testing import assert_frame_equal, assert_series_equal +from pyre_extensions import assert_is_instance from scipy.special import expit, logit @@ -107,16 +108,19 @@ def test_InvalidSettings(self) -> None: def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) - # pyre-fixme[16]: `Parameter` has no attribute `lower`. - self.assertEqual(ss2.parameters["x"].lower, logit(0.9)) - # pyre-fixme[16]: `Parameter` has no attribute `upper`. - self.assertEqual(ss2.parameters["x"].upper, logit(0.999)) + self.assertEqual( + assert_is_instance(ss2.parameters["x"], RangeParameter).lower, logit(0.9) + ) + self.assertEqual( + assert_is_instance(ss2.parameters["x"], RangeParameter).upper, logit(0.999) + ) t2 = Logit(search_space=self.search_space_with_target) ss_target = deepcopy(self.search_space_with_target) t2.transform_search_space(ss_target) self.assertEqual(ss_target.parameters["x"].target_value, logit(0.123)) - self.assertEqual(ss_target.parameters["x"].lower, logit(0.1)) - self.assertEqual(ss_target.parameters["x"].upper, logit(0.3)) + x_param = assert_is_instance(ss_target.parameters["x"], RangeParameter) + self.assertEqual(x_param.lower, logit(0.1)) + self.assertEqual(x_param.upper, logit(0.3)) def test_transform_experiment_data(self) -> None: parameterizations = [ diff --git a/ax/adapter/transforms/tests/test_metrics_as_task_transform.py b/ax/adapter/transforms/tests/test_metrics_as_task_transform.py index 44472179380..dd287754588 100644 --- a/ax/adapter/transforms/tests/test_metrics_as_task_transform.py +++ b/ax/adapter/transforms/tests/test_metrics_as_task_transform.py @@ -14,6 +14,7 @@ from ax.core.parameter import ChoiceParameter from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_search_space_for_range_values +from pyre_extensions import assert_is_instance class MetricsAsTaskTransformTest(TestCase): @@ -125,10 +126,10 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(len(new_ss.parameters), 3) new_param = new_ss.parameters["METRIC_TASK"] self.assertIsInstance(new_param, ChoiceParameter) + new_param_choice = assert_is_instance(new_param, ChoiceParameter) self.assertEqual( - # pyre-fixme[16]: `Parameter` has no attribute `values`. - new_param.values, + new_param_choice.values, ["TARGET", "metric1", "metric2"], ) - self.assertTrue(new_param.is_task) # pyre-ignore + self.assertTrue(new_param_choice.is_task) self.assertEqual(new_param.target_value, "TARGET") diff --git a/ax/adapter/transforms/tests/test_one_hot_transform.py b/ax/adapter/transforms/tests/test_one_hot_transform.py index fb7f41accf4..f685c6b6d67 100644 --- a/ax/adapter/transforms/tests/test_one_hot_transform.py +++ b/ax/adapter/transforms/tests/test_one_hot_transform.py @@ -25,6 +25,7 @@ from ax.utils.testing.core_stubs import get_experiment_with_observations from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyre_extensions import assert_is_instance class OneHotTransformTest(TestCase): @@ -130,10 +131,18 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.FLOAT) # Parameter range fixed to [0,1]. - # pyre-fixme[16]: `Parameter` has no attribute `lower`. - self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "0"].lower, 0.0) - # pyre-fixme[16]: `Parameter` has no attribute `upper`. - self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "1"].upper, 1.0) + self.assertEqual( + assert_is_instance( + ss2.parameters["b" + OH_PARAM_INFIX + "0"], RangeParameter + ).lower, + 0.0, + ) + self.assertEqual( + assert_is_instance( + ss2.parameters["b" + OH_PARAM_INFIX + "1"], RangeParameter + ).upper, + 1.0, + ) self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL) # Ensure we error if we try to transform a fidelity parameter diff --git a/ax/adapter/transforms/tests/test_task_encode_transform.py b/ax/adapter/transforms/tests/test_task_encode_transform.py index 42e95d25723..8a985685c32 100644 --- a/ax/adapter/transforms/tests/test_task_encode_transform.py +++ b/ax/adapter/transforms/tests/test_task_encode_transform.py @@ -13,6 +13,7 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance class TaskChoiceToIntTaskChoiceTransformTest(TestCase): @@ -73,8 +74,9 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(ss2.parameters["b"].parameter_type, ParameterType.FLOAT) self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.INT) - # pyre-fixme[16]: `Parameter` has no attribute `values`. - self.assertEqual(ss2.parameters["c"].values, [0, 1]) + self.assertEqual( + assert_is_instance(ss2.parameters["c"], ChoiceParameter).values, [0, 1] + ) self.assertEqual(ss2.parameters["c"].target_value, 0) self.assertEqual(ss2.parameters["c"].dependents, {0: ["b"]}) diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 8a2c201f4b9..473f2ac68c6 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -212,11 +212,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.STAGED], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.STAGED + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.STAGED + ) ) self.assertIsNotNone(self.batch.time_staged) self.assertTrue(self.batch.status.is_deployed) @@ -240,11 +240,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.RUNNING], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.RUNNING + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.RUNNING + ) ) self.assertIsNotNone(self.batch.time_run_started) self.assertTrue(self.batch.status.expecting_data) @@ -261,11 +261,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.COMPLETED], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.COMPLETED + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.COMPLETED + ) ) self.assertIsNotNone(self.batch.time_completed) self.assertTrue(self.batch.status.is_terminal) @@ -296,11 +296,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.CANDIDATE], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.CANDIDATE + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.CANDIDATE + ) ) def test_AbandonBatchTrial(self) -> None: @@ -592,11 +592,9 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None: # Check that if we add cand. metadata to gr_2, it will appear in cand. # metadata for the batch. gr_3 = get_generator_run2() - new_cand_metadata = { + new_cand_metadata: dict[str, dict[str, str] | None] | None = { a.signature: {"md_key": f"md_val_{a.signature}"} for a in gr_3.arms } - # pyre-fixme[8]: Attribute has type `Optional[Dict[str, Optional[Dict[str, - # typing.Any]]]]`; used as `Dict[str, Dict[str, str]]`. gr_3._candidate_metadata_by_arm_signature = new_cand_metadata self.batch.add_generator_run(gr_3) gr_3 = self.batch._generator_runs[-1] diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 9360c93293f..ed3b22900d8 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -17,6 +17,7 @@ get_optimization_config, get_search_space, ) +from pyre_extensions import none_throws GENERATOR_RUN_STR = "GeneratorRun(3 arms, total weight 3.0)" @@ -31,7 +32,7 @@ def setUp(self) -> None: self.search_space = get_search_space() self.arms = get_arms() - self.weights = [2, 1, 1] + self.weights: list[float] = [2, 1, 1] self.unweighted_run = GeneratorRun( arms=self.arms, optimization_config=self.optimization_config, @@ -42,8 +43,6 @@ def setUp(self) -> None: ) self.weighted_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=self.optimization_config, search_space=self.search_space, @@ -56,13 +55,13 @@ def setUp(self) -> None: def test_Init(self) -> None: self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `outcome_constraints`. - len(self.unweighted_run.optimization_config.outcome_constraints), + len( + none_throws(self.unweighted_run.optimization_config).outcome_constraints + ), len(self.optimization_config.outcome_constraints), ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `parameters`. - len(self.unweighted_run.search_space.parameters), + len(none_throws(self.unweighted_run.search_space).parameters), len(self.search_space.parameters), ) self.assertEqual(str(self.unweighted_run), GENERATOR_RUN_STR) @@ -120,8 +119,6 @@ def test_ModelPredictions(self) -> None: ) run_no_model_predictions = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -150,8 +147,6 @@ def test_ParamDf(self) -> None: def test_BestArm(self) -> None: generator_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -166,8 +161,6 @@ def test_GenMetadata(self) -> None: gm = {"hello": "world"} generator_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -178,14 +171,10 @@ def test_GenMetadata(self) -> None: def test_Sortable(self) -> None: generator_run1 = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, ) generator_run2 = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) diff --git a/ax/core/tests/test_objective.py b/ax/core/tests/test_objective.py index 57d21c66714..6e28cc39b09 100644 --- a/ax/core/tests/test_objective.py +++ b/ax/core/tests/test_objective.py @@ -64,8 +64,7 @@ def test_Init(self) -> None: def test_MultiObjective(self) -> None: with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.multi_objective.metric + self.multi_objective.metric self.assertEqual(self.multi_objective.metrics, list(self.metrics.values())) minimizes = [obj.minimize for obj in self.multi_objective.objectives] @@ -106,8 +105,7 @@ def test_MultiObjective(self) -> None: def test_ScalarizedObjective(self) -> None: with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.scalarized_objective.metric + self.scalarized_objective.metric self.assertEqual( self.scalarized_objective.metrics, [self.metrics["m1"], self.metrics["m2"]] diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c18ccf25deb..c812bd69c2b 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -36,21 +36,18 @@ class ObservationsTest(TestCase): def test_ObservationFeatures(self) -> None: t = np.datetime64("now") + obsf = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + trial_index=2, + start_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + end_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + ) attrs = { "parameters": {"x": 0, "y": "a"}, "trial_index": 2, "start_time": t, "end_time": t, } - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf = ObservationFeatures(**attrs) for k, v in attrs.items(): self.assertEqual(getattr(obsf, k), v) printstr = ( @@ -58,29 +55,21 @@ def test_ObservationFeatures(self) -> None: f"start_time={t}, end_time={t})" ) self.assertEqual(repr(obsf), printstr) - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf2 = ObservationFeatures(**attrs) + obsf2 = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + trial_index=2, + start_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + end_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + ) self.assertEqual(hash(obsf), hash(obsf2)) a = {obsf, obsf2} self.assertEqual(len(a), 1) self.assertEqual(obsf, obsf2) - attrs.pop("trial_index") - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf3 = ObservationFeatures(**attrs) + obsf3 = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + start_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + end_time=t, # pyre-ignore[6]: datetime64 vs Timestamp. + ) self.assertNotEqual(obsf, obsf3) self.assertFalse(obsf == 1) @@ -105,12 +94,9 @@ def test_ObservationFeaturesFromArm(self) -> None: self.assertEqual(obsf.trial_index, 3) def test_UpdateFeatures(self) -> None: - parameters = {"x": 0, "y": "a"} - new_parameters = {"z": "foo"} + parameters: TParameterization = {"x": 0, "y": "a"} + new_parameters: TParameterization = {"z": "foo"} - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Dict[str, Union[int, str]]`. - # pyre-fixme[6]: For 2nd param expected `Optional[int64]` but got `int`. obsf = ObservationFeatures(parameters=parameters, trial_index=3) # Ensure None trial_index doesn't override existing value @@ -119,8 +105,6 @@ def test_UpdateFeatures(self) -> None: # Test override new_obsf = ObservationFeatures( - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Dict[str, str]`. parameters=new_parameters, trial_index=4, start_time=pd.Timestamp("2005-02-25"), @@ -133,16 +117,19 @@ def test_UpdateFeatures(self) -> None: self.assertEqual(obsf.end_time, pd.Timestamp("2005-02-26")) def test_ObservationData(self) -> None: + metric_signatures = ["a", "b"] + means = np.array([4.0, 5.0]) + covariance = np.array([[1.0, 4.0], [3.0, 6.0]]) + obsd = ObservationData( + metric_signatures=metric_signatures, + means=means, + covariance=covariance, + ) attrs = { - "metric_signatures": ["a", "b"], - "means": np.array([4.0, 5.0]), - "covariance": np.array([[1.0, 4.0], [3.0, 6.0]]), + "metric_signatures": metric_signatures, + "means": means, + "covariance": covariance, } - # pyre-fixme[6]: For 1st param expected `List[str]` but got - # `Union[List[str], ndarray]`. - # pyre-fixme[6]: For 1st param expected `ndarray` but got `Union[List[str], - # ndarray]`. - obsd = ObservationData(**attrs) self.assertEqual(obsd.metric_signatures, attrs["metric_signatures"]) self.assertTrue(np.array_equal(obsd.means, attrs["means"])) self.assertTrue(np.array_equal(obsd.covariance, attrs["covariance"])) @@ -258,19 +245,18 @@ def test_ObservationsFromData(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[int, str]], float, - # str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} trials = { obs["trial_index"]: Trial( - experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), ) for obs in truth } @@ -375,8 +361,7 @@ def test_ObservationsFromMapData(self) -> None: arms = [ Arm( name=assert_is_instance(obs["arm_name"], str), - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool,... - parameters=obs["parameters"], + parameters=assert_is_instance(obs["parameters"], dict), ) for obs in truth ] @@ -419,10 +404,20 @@ def test_ObservationsFromMapData(self) -> None: self.assertEqual(obs.features.trial_index, t["trial_index"]) self.assertEqual(obs.data.metric_signatures, [t["metric_name"]]) self.assertEqual(obs.data.metric_signatures, [t["metric_signature"]]) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.means, t["mean_t"])) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"])) + self.assertTrue( + np.array_equal( + obs.data.means, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(t["mean_t"], np.ndarray), + ) + ) + self.assertTrue( + np.array_equal( + obs.data.covariance, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(t["covariance_t"], np.ndarray), + ) + ) self.assertEqual(obs.arm_name, t["arm_name"]) self.assertEqual(obs.features.metadata, {"step": t["step"]}) @@ -514,37 +509,28 @@ def test_ObservationsFromDataAbandoned(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[float, str]], Dict[str, Union[int, str]], float, - # ndarray, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[float, str]], - # Dict[str, Union[int, str]], float, ndarray, str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} - trials = { - obs["trial_index"]: ( - Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])) + trials: dict[int, Trial | BatchTrial] = { + assert_is_instance(obs["trial_index"], int): ( + Trial( + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), + ) ) for obs in truth[:-1] - # pyre-fixme[16]: Item `Dict` of `Union[Dict[str, typing.Union[float, - # str]], Dict[str, typing.Union[int, str]], float, ndarray, str]` has no - # attribute `startswith`. - if not obs["arm_name"].startswith("2") + if not assert_is_instance(obs["arm_name"], str).startswith("2") } batch = BatchTrial(experiment, GeneratorRun(arms=[arms["2_0"], arms["2_1"]])) - # pyre-fixme[6]: For 1st param expected - # `SupportsKeysAndGetItem[Union[Dict[str, Union[float, str]], Dict[str, - # Union[int, str]], float, ndarray, str], Trial]` but got `Dict[int, - # BatchTrial]`. - trials.update({2: batch}) - # pyre-fixme[16]: Optional type has no attribute `mark_abandoned`. - trials.get(1).mark_abandoned() - # pyre-fixme[16]: Optional type has no attribute `mark_arm_abandoned`. - trials.get(2).mark_arm_abandoned(arm_name="2_1") + trials[2] = batch + none_throws(trials.get(1)).mark_abandoned() + assert_is_instance(trials.get(2), BatchTrial).mark_arm_abandoned(arm_name="2_1") type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) type(experiment).metrics = PropertyMock( @@ -627,19 +613,18 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[None, Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[None, Dict[str, Union[int, str]], - # float, str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} trials = { obs["trial_index"]: Trial( - experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), ) for obs in truth } @@ -789,11 +774,19 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial( self.assertEqual( obs.data.metric_signatures, obs_truth["metric_signatures"][i] ) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i])) self.assertTrue( - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtyp... - np.array_equal(obs.data.covariance, obs_truth["covariance"][i]) + np.array_equal( + obs.data.means, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(obs_truth["means"][i], np.ndarray), + ) + ) + self.assertTrue( + np.array_equal( + obs.data.covariance, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(obs_truth["covariance"][i], np.ndarray), + ) ) self.assertEqual(obs.arm_name, obs_truth["arm_name"][i]) self.assertEqual(obs.arm_name, obs_truth["arm_name"][i]) @@ -875,12 +868,10 @@ def test_ObservationsWithCandidateMetadata(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[int, str]], float, - # str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() @@ -889,9 +880,9 @@ def test_ObservationsWithCandidateMetadata(self) -> None: obs["trial_index"]: Trial( experiment, GeneratorRun( - arms=[arms[obs["arm_name"]]], + arms=[arms[assert_is_instance(obs["arm_name"], str)]], candidate_metadata_by_arm_signature={ - arms[obs["arm_name"]].signature: { + arms[assert_is_instance(obs["arm_name"], str)].signature: { SOME_METADATA_KEY: f"value_{obs['trial_index']}" } }, @@ -919,8 +910,7 @@ def test_ObservationsWithCandidateMetadata(self) -> None: observations = observations_from_data(experiment, data) for observation in observations: self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - observation.features.metadata.get(SOME_METADATA_KEY), + none_throws(observation.features.metadata).get(SOME_METADATA_KEY), f"value_{observation.features.trial_index}", ) diff --git a/ax/core/tests/test_outcome_constraint.py b/ax/core/tests/test_outcome_constraint.py index fcc6ed79cf5..f288ece1703 100644 --- a/ax/core/tests/test_outcome_constraint.py +++ b/ax/core/tests/test_outcome_constraint.py @@ -253,8 +253,7 @@ def test_RaiseError(self) -> None: ) with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.constraint.metric + self.constraint.metric with self.assertRaises(NotImplementedError): self.constraint.metric = self.metrics[0] diff --git a/ax/core/tests/test_parameter_constraint.py b/ax/core/tests/test_parameter_constraint.py index 822046e1ec8..30bc1be4c2c 100644 --- a/ax/core/tests/test_parameter_constraint.py +++ b/ax/core/tests/test_parameter_constraint.py @@ -82,16 +82,12 @@ def test_Repr(self) -> None: self.assertEqual(str(self.constraint), self.constraint_repr) def test_Validate(self) -> None: - parameters = {"x": 4, "z": 3} + parameters: dict[str, float | int] = {"x": 4, "z": 3} with self.assertRaises(ValueError): - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[float, int]]` - # but got `Dict[str, int]`. self.constraint.check(parameters) # check slack constraint parameters = {"x": 4, "y": 1} - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[float, int]]` but - # got `Dict[str, int]`. self.assertTrue(self.constraint.check(parameters)) # check tight constraint (within numerical tolerance) diff --git a/ax/core/tests/test_runner.py b/ax/core/tests/test_runner.py index 24edb5eb61e..40a34d33276 100644 --- a/ax/core/tests/test_runner.py +++ b/ax/core/tests/test_runner.py @@ -15,8 +15,7 @@ class DummyRunner(Runner): - # pyre-fixme[3]: Return type must be annotated. - def run(self, trial: BaseTrial): + def run(self, trial: BaseTrial) -> dict[str, str]: return {"metadatum": f"value_for_trial_{trial.index}"} diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 33eef439201..d638968896d 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -114,13 +114,12 @@ def test_basic_properties(self) -> None: def test_adding_new_trials(self) -> None: new_arm = get_arms()[1] - cand_metadata = {new_arm.signature: {"a": "b"}} + cand_metadata: dict[str, dict[str, str] | None] = { + new_arm.signature: {"a": "b"} + } new_trial = self.experiment.new_trial( generator_run=GeneratorRun( arms=[new_arm], - # pyre-fixme[6]: For 2nd param expected `Optional[Dict[str, - # Optional[Dict[str, typing.Any]]]]` but got `Dict[str, Dict[str, - # str]]`. candidate_metadata_by_arm_signature=cand_metadata, ) ) @@ -313,15 +312,13 @@ def stop(self, trial, reason): f"{BaseTrial.__module__}.{BaseTrial.__name__}.lookup_data", return_value=TEST_DATA, ) - # pyre-fixme[3]: Return type must be annotated. - def test_objective_mean(self, _mock): + def test_objective_mean(self, _mock: Mock) -> None: self.assertEqual(self.trial.objective_mean, 1.0) @patch( f"{BaseTrial.__module__}.{BaseTrial.__name__}.lookup_data", return_value=Data() ) - # pyre-fixme[3]: Return type must be annotated. - def test_objective_mean_empty_df(self, _mock): + def test_objective_mean_empty_df(self, _mock: Mock) -> None: with self.assertRaisesRegex(ValueError, "not yet in data for trial."): self.assertIsNone(self.trial.objective_mean) diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index f0810e2949c..ccd64b6c4ec 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -22,6 +22,7 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import OutcomeConstraint +from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.core.utils import ( @@ -51,7 +52,7 @@ get_experiment, get_hierarchical_search_space_experiment, ) -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws class UtilsTest(TestCase): @@ -1151,7 +1152,7 @@ def test_curve_data(self) -> None: ) trial = exp.trials[0] trial.mark_running(no_runner_required=True) - arm_name = trial.arm.name # pyre-ignore[16] + arm_name = none_throws(assert_is_instance(trial, Trial).arm).name # Both metrics present at various steps → COMPLETE. df_both = pd.DataFrame( @@ -1183,7 +1184,7 @@ def test_curve_data(self) -> None: exp2.optimization_config = none_throws(exp.optimization_config) trial2 = exp2.trials[0] trial2.mark_running(no_runner_required=True) - arm_name2 = trial2.arm.name # pyre-ignore[16] + arm_name2 = none_throws(assert_is_instance(trial2, Trial).arm).name df_partial = pd.DataFrame( [ { From 7bf020e8dfd584b6348a464bc7488beca8de5f57 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 6 Mar 2026 10:19:54 -0800 Subject: [PATCH 3/3] Remove pyre-fixme/pyre-ignore from ax/service, ax/storage, ax/utils test files (#4990) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4990 Remove pyre-fixme and pyre-ignore type suppression comments from test files in ax/service/tests, ax/storage/*/tests, and ax/utils/*/tests. Uses proper type narrowing via none_throws, assert_is_instance, cast, and explicit type annotations instead of suppression comments. Differential Revision: D95273568 --- ax/service/tests/test_ax_client.py | 314 ++++++++---------- ax/service/tests/test_global_stopping.py | 7 +- ax/service/tests/test_instantiation_utils.py | 7 +- ax/service/tests/test_interactive_loop.py | 33 +- ax/service/tests/test_managed_loop.py | 95 +++--- ax/service/tests/test_report_utils.py | 23 +- .../json_store/tests/test_json_store.py | 34 +- ax/storage/sqa_store/tests/test_sqa_store.py | 89 ++--- .../tests/test_with_db_settings_base.py | 18 +- ax/utils/common/tests/test_docutils.py | 9 +- ax/utils/common/tests/test_equality.py | 4 +- ax/utils/common/tests/test_executils.py | 6 +- ax/utils/common/tests/test_testutils.py | 5 +- .../testing/tests/test_backend_simulator.py | 4 +- ax/utils/testing/tests/test_mock.py | 2 +- 15 files changed, 308 insertions(+), 342 deletions(-) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 705237fd89f..aaa48b3028d 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -10,9 +10,10 @@ import sys import time import warnings +from collections.abc import Sequence from itertools import product from math import ceil -from typing import Any, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from unittest import mock from unittest.mock import Mock, patch @@ -25,6 +26,7 @@ from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment +from ax.core.objective import MultiObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint from ax.core.parameter import ( @@ -62,6 +64,7 @@ MaxGenerationParallelism, MaxTrialsAwaitingData, ) +from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.metrics.branin import branin, BraninMetric from ax.runners.synthetic import SyntheticRunner from ax.service.ax_client import AxClient, ObjectiveProperties @@ -117,8 +120,7 @@ def run_trials_using_recommended_parallelism( remaining_trials -= 1 for _ in range(parallelism_setting): params, idx = in_flight_trials.pop() - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U... - ax_client.complete_trial(idx, branin(params["x"], params["y"])) + ax_client.complete_trial(idx, float(branin(params["x"], params["y"]))) # If all went well and no errors were raised, remaining_trials should be 0. return remaining_trials @@ -152,18 +154,18 @@ def get_branin_currin_optimization_with_N_sobol_trials( objectives={ "branin": ObjectiveProperties( minimize=minimize, - # pyre-fixme[6]: For 2nd param expected `Optional[float]` but got - # `Optional[Tensor]`. threshold=( - branin_currin.ref_point[0] if include_objective_thresholds else None + float(branin_currin.ref_point[0]) + if include_objective_thresholds + else None ), ), "currin": ObjectiveProperties( minimize=minimize, - # pyre-fixme[6]: For 2nd param expected `Optional[float]` but got - # `Optional[Tensor]`. threshold=( - branin_currin.ref_point[1] if include_objective_thresholds else None + float(branin_currin.ref_point[1]) + if include_objective_thresholds + else None ), ), }, @@ -333,9 +335,7 @@ def test_set_status_quo(self) -> None: ], ) self.assertIsNone(ax_client.status_quo) - status_quo_params = {"x": 1.0, "y": 1.0} - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, Union[None, - # bool, float, int, str]]]` but got `Dict[str, float]`. + status_quo_params: TParameterization = {"x": 1.0, "y": 1.0} ax_client.set_status_quo(status_quo_params) self.assertEqual( ax_client.experiment.status_quo, @@ -343,7 +343,7 @@ def test_set_status_quo(self) -> None: ) def test_status_quo_property(self) -> None: - status_quo_params = {"x": 1.0, "y": 1.0} + status_quo_params: TParameterization = {"x": 1.0, "y": 1.0} ax_client = AxClient() ax_client.create_experiment( name="test", @@ -351,8 +351,6 @@ def test_status_quo_property(self) -> None: {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], - # pyre-fixme[6]: For 3rd param expected `Optional[Dict[str, Union[None, - # bool, float, int, str]]]` but got `Dict[str, float]`. status_quo=status_quo_params, ) self.assertEqual(ax_client.status_quo, status_quo_params) @@ -379,34 +377,36 @@ def test_set_optimization_config_to_moo_with_constraints(self) -> None: }, outcome_constraints=["baz >= 7.2%"], ) - opt_config = ax_client.experiment.optimization_config + opt_config = assert_is_instance( + ax_client.experiment.optimization_config, + MultiObjectiveOptimizationConfig, + ) + objective = assert_is_instance(opt_config.objective, MultiObjective) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `objective`. - opt_config.objective.objectives[0].metric.signature, + objective.objectives[0].metric.signature, "foo", ) self.assertEqual( - opt_config.objective.objectives[0].metric.name, + objective.objectives[0].metric.name, "foo", ) self.assertEqual( - opt_config.objective.objectives[0].minimize, + objective.objectives[0].minimize, True, ) self.assertEqual( - opt_config.objective.objectives[1].metric.signature, + objective.objectives[1].metric.signature, "bar", ) self.assertEqual( - opt_config.objective.objectives[1].metric.name, + objective.objectives[1].metric.name, "bar", ) self.assertEqual( - opt_config.objective.objectives[1].minimize, + objective.objectives[1].minimize, False, ) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `objective_thresholds`. opt_config.objective_thresholds[0], ObjectiveThreshold( metric=Metric(name="foo", lower_is_better=True), @@ -425,7 +425,6 @@ def test_set_optimization_config_to_moo_with_constraints(self) -> None: ), ) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `outcome_constraints`. opt_config.outcome_constraints[0], OutcomeConstraint( metric=Metric(name="baz", lower_is_better=False), @@ -451,9 +450,8 @@ def test_set_optimization_config_to_single_objective(self) -> None: }, outcome_constraints=["baz >= 7.2%"], ) - opt_config = ax_client.experiment.optimization_config + opt_config = none_throws(ax_client.experiment.optimization_config) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `objective`. opt_config.objective.metric.signature, "foo", ) @@ -462,7 +460,6 @@ def test_set_optimization_config_to_single_objective(self) -> None: True, ) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `outcome_constraints`. opt_config.outcome_constraints[0], OutcomeConstraint( metric=Metric(name="baz", lower_is_better=False), @@ -820,12 +817,11 @@ def test_create_experiment(self) -> None: tracking_metric_names=["test_tracking_metric"], is_test=True, ) - assert ax_client._experiment is not None + experiment = none_throws(ax_client._experiment) self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment") - self.assertEqual(ax_client._experiment, ax_client.experiment) + self.assertEqual(experiment, ax_client.experiment) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `search_space`. - ax_client._experiment.search_space.parameters["x"], + experiment.search_space.parameters["x"], RangeParameter( name="x", parameter_type=ParameterType.FLOAT, @@ -836,7 +832,7 @@ def test_create_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["y"], + experiment.search_space.parameters["y"], ChoiceParameter( name="y", parameter_type=ParameterType.INT, @@ -845,17 +841,17 @@ def test_create_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x3"], + experiment.search_space.parameters["x3"], FixedParameter(name="x3", parameter_type=ParameterType.INT, value=2), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x4"], + experiment.search_space.parameters["x4"], RangeParameter( name="x4", parameter_type=ParameterType.INT, lower=1.0, upper=3.0 ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x5"], + experiment.search_space.parameters["x5"], ChoiceParameter( name="x5", parameter_type=ParameterType.STRING, @@ -863,22 +859,22 @@ def test_create_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x6"], + experiment.search_space.parameters["x6"], RangeParameter( name="x6", parameter_type=ParameterType.INT, lower=1.0, upper=3.0 ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x7"], + experiment.search_space.parameters["x7"], DerivedParameter( name="x7", parameter_type=ParameterType.INT, expression_str=expression_str, ), ) + opt_config = none_throws(experiment.optimization_config) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `optimization_config`. - ax_client._experiment.optimization_config.outcome_constraints[0], + opt_config.outcome_constraints[0], OutcomeConstraint( metric=Metric(name="some_metric", lower_is_better=False), op=ComparisonOp.GEQ, @@ -887,7 +883,7 @@ def test_create_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.optimization_config.outcome_constraints[1], + opt_config.outcome_constraints[1], OutcomeConstraint( metric=Metric(name="some_metric", lower_is_better=True), op=ComparisonOp.LEQ, @@ -895,15 +891,12 @@ def test_create_experiment(self) -> None: relative=False, ), ) - self.assertTrue(ax_client._experiment.optimization_config.objective.minimize) + self.assertTrue(opt_config.objective.minimize) self.assertDictEqual( - # pyre-fixme[16]: `Optional` has no attribute `_tracking_metrics`. - ax_client._experiment._tracking_metrics, + experiment._tracking_metrics, {"test_tracking_metric": Metric(name="test_tracking_metric")}, ) - # pyre-fixme[16]: `Optional` has no attribute - # `immutable_search_space_and_opt_config`. - self.assertTrue(ax_client._experiment.immutable_search_space_and_opt_config) + self.assertTrue(experiment.immutable_search_space_and_opt_config) self.assertTrue(ax_client.experiment.is_test) with self.subTest("objective_name"): @@ -1113,16 +1106,19 @@ def test_create_experiment_with_metric_definitions(self) -> None: metric_definitions=metric_definitions, is_test=True, ) - # pyre-fixme[16]: `Optional` has no attribute `objective`. - objectives = ax_client.experiment.optimization_config.objective.objectives + opt_config = assert_is_instance( + ax_client.experiment.optimization_config, + MultiObjectiveOptimizationConfig, + ) + objective = assert_is_instance(opt_config.objective, MultiObjective) + objectives = objective.objectives self.assertEqual(objectives[0].metric.signature, "obj_m1") self.assertEqual(objectives[0].metric.name, "obj_m1") self.assertEqual(objectives[0].metric.properties, {"m1_opt": "m1_val"}) self.assertEqual(objectives[1].metric.signature, "obj_m2") self.assertEqual(objectives[1].metric.name, "obj_m2") self.assertEqual(objectives[1].metric.properties, {"m2_opt": "m2_val"}) - # pyre-fixme[16]: `Optional` has no attribute `objective_thresholds`. - thresholds = ax_client.experiment.optimization_config.objective_thresholds + thresholds = opt_config.objective_thresholds self.assertEqual(thresholds[0].metric.signature, "obj_m1") self.assertEqual(thresholds[0].metric.name, "obj_m1") @@ -1131,10 +1127,7 @@ def test_create_experiment_with_metric_definitions(self) -> None: self.assertEqual(thresholds[1].metric.name, "obj_m2") self.assertEqual(thresholds[1].metric.properties, {"m2_opt": "m2_val"}) - outcome_constraints = ( - # pyre-fixme[16]: `Optional` has no attribute `outcome_constraints`. - ax_client.experiment.optimization_config.outcome_constraints - ) + outcome_constraints = opt_config.outcome_constraints self.assertEqual(outcome_constraints[0].metric.signature, "const_m3") self.assertEqual(outcome_constraints[0].metric.name, "const_m3") self.assertEqual(outcome_constraints[0].metric.properties, {"m3_opt": "m3_val"}) @@ -1218,26 +1211,26 @@ def test_set_optimization_config_with_metric_definitions(self) -> None: outcome_constraints=["const_m3 >= 3"], metric_definitions=metric_definitions, ) - # pyre-fixme[16]: `Optional` has no attribute `objective`. - objectives = ax_client.experiment.optimization_config.objective.objectives + opt_config = assert_is_instance( + ax_client.experiment.optimization_config, + MultiObjectiveOptimizationConfig, + ) + objective = assert_is_instance(opt_config.objective, MultiObjective) + objectives = objective.objectives self.assertEqual(objectives[0].metric.signature, "obj_m1") self.assertEqual(objectives[0].metric.name, "obj_m1") self.assertEqual(objectives[0].metric.properties, {"m1_opt": "m1_val"}) self.assertEqual(objectives[1].metric.signature, "obj_m2") self.assertEqual(objectives[1].metric.name, "obj_m2") self.assertEqual(objectives[1].metric.properties, {"m2_opt": "m2_val"}) - # pyre-fixme[16]: `Optional` has no attribute `objective_thresholds`. - thresholds = ax_client.experiment.optimization_config.objective_thresholds + thresholds = opt_config.objective_thresholds self.assertEqual(thresholds[0].metric.signature, "obj_m1") self.assertEqual(thresholds[0].metric.name, "obj_m1") self.assertEqual(thresholds[0].metric.properties, {"m1_opt": "m1_val"}) self.assertEqual(thresholds[1].metric.signature, "obj_m2") self.assertEqual(thresholds[1].metric.name, "obj_m2") self.assertEqual(thresholds[1].metric.properties, {"m2_opt": "m2_val"}) - outcome_constraints = ( - # pyre-fixme[16]: `Optional` has no attribute `outcome_constraints`. - ax_client.experiment.optimization_config.outcome_constraints - ) + outcome_constraints = opt_config.outcome_constraints self.assertEqual(outcome_constraints[0].metric.signature, "const_m3") self.assertEqual(outcome_constraints[0].metric.name, "const_m3") self.assertEqual(outcome_constraints[0].metric.properties, {"m3_opt": "m3_val"}) @@ -1604,11 +1597,10 @@ def test_create_moo_experiment(self) -> None: tracking_metric_names=["test_tracking_metric"], is_test=True, ) - assert ax_client._experiment is not None - self.assertEqual(ax_client._experiment, ax_client.experiment) + experiment = none_throws(ax_client._experiment) + self.assertEqual(experiment, ax_client.experiment) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `search_space`. - ax_client._experiment.search_space.parameters["x"], + experiment.search_space.parameters["x"], RangeParameter( name="x", parameter_type=ParameterType.FLOAT, @@ -1619,7 +1611,7 @@ def test_create_moo_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["y"], + experiment.search_space.parameters["y"], ChoiceParameter( name="y", parameter_type=ParameterType.INT, @@ -1628,35 +1620,39 @@ def test_create_moo_experiment(self) -> None: ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x3"], + experiment.search_space.parameters["x3"], FixedParameter(name="x3", parameter_type=ParameterType.INT, value=2), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x4"], + experiment.search_space.parameters["x4"], RangeParameter( name="x4", parameter_type=ParameterType.INT, lower=1.0, upper=3.0 ), ) self.assertEqual( - ax_client._experiment.search_space.parameters["x5"], + experiment.search_space.parameters["x5"], ChoiceParameter( name="x5", parameter_type=ParameterType.STRING, values=["one", "two", "three"], ), ) - # pyre-fixme[16]: `Optional` has no attribute `optimization_config`. - optimization_config = ax_client._experiment.optimization_config + optimization_config = assert_is_instance( + experiment.optimization_config, MultiObjectiveOptimizationConfig + ) + multi_objective = assert_is_instance( + optimization_config.objective, MultiObjective + ) self.assertEqual( - [m.name for m in optimization_config.objective.metrics], + [m.name for m in multi_objective.metrics], ["test_objective_1", "test_objective_2"], ) self.assertEqual( - [o.minimize for o in optimization_config.objective.objectives], + [o.minimize for o in multi_objective.objectives], [True, False], ) self.assertEqual( - [m.lower_is_better for m in optimization_config.objective.metrics], + [m.lower_is_better for m in multi_objective.metrics], [True, False], ) self.assertEqual( @@ -1694,13 +1690,10 @@ def test_create_moo_experiment(self) -> None: ), ) self.assertDictEqual( - # pyre-fixme[16]: `Optional` has no attribute `_tracking_metrics`. - ax_client._experiment._tracking_metrics, + experiment._tracking_metrics, {"test_tracking_metric": Metric(name="test_tracking_metric")}, ) - # pyre-fixme[16]: `Optional` has no attribute - # `immutable_search_space_and_opt_config`. - self.assertTrue(ax_client._experiment.immutable_search_space_and_opt_config) + self.assertTrue(experiment.immutable_search_space_and_opt_config) self.assertTrue(ax_client.experiment.is_test) with self.subTest("objective_name name raises UnsupportedError"): @@ -1738,16 +1731,15 @@ def test_raw_data_format(self) -> None: {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) + trial_index = 0 for _ in range(6): parameterization, trial_index = ax_client.get_next_trial() x, y = parameterization.get("x"), parameterization.get("y") - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - ax_client.complete_trial(trial_index, raw_data=(branin(x, y), 0.0)) + ax_client.complete_trial(trial_index, raw_data=(float(branin(x, y)), 0.0)) with self.assertRaisesRegex( UserInputError, "Raw data does not conform to the expected structure." ): ax_client._update_trial_with_raw_data( - # pyre-fixme[61]: `trial_index` is undefined, or not always defined. trial_index=trial_index, # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[... raw_data="invalid data", @@ -1870,10 +1862,9 @@ def test_update_running_trial_with_intermediate_data(self) -> None: self.assertTrue(data.has_step_column) self.assertIn(t, data.df[MAP_KEY]) - # pyre-fixme[56]: Pyre was not able to infer the type of argument `f"{ax.service.... @patch( f"{get_best_parameters_from_model_predictions_with_trial_index.__module__}" - + ".get_best_parameters_from_model_predictions_with_trial_index", + ".get_best_parameters_from_model_predictions_with_trial_index", wraps=get_best_parameters_from_model_predictions_with_trial_index, ) def test_get_best_point_no_model_predictions( @@ -1899,17 +1890,18 @@ def test_trial_completion(self) -> None: metrics_in_data = ax_client.experiment.fetch_data().df["metric_name"].values self.assertNotIn("m1", metrics_in_data) self.assertIn("branin", metrics_in_data) - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - self.assertEqual(ax_client.get_best_parameters()[0], params) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params) params2, idy = ax_client.get_next_trial() ax_client.complete_trial(trial_index=idy, raw_data=(-1, 0.0)) - self.assertEqual(ax_client.get_best_parameters()[0], params2) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params2) params3, idx3 = ax_client.get_next_trial() ax_client.complete_trial(trial_index=idx3, raw_data=-2) - self.assertEqual(ax_client.get_best_parameters()[0], params3) - best_trial_values = ax_client.get_best_parameters()[1] + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params3) + best_trial_values = none_throws(none_throws(ax_client.get_best_parameters())[1]) self.assertEqual(best_trial_values[0], {"branin": -2.0}) - self.assertTrue(math.isnan(best_trial_values[1]["branin"]["branin"])) + self.assertTrue( + math.isnan(none_throws(best_trial_values[1])["branin"]["branin"]) + ) def test_update_trial_data(self) -> None: ax_client = get_branin_optimization(support_intermediate_data=True) @@ -2000,13 +1992,11 @@ def test_ttl_trial(self) -> None: # A ttl trial that ends adds no data. params, idx = ax_client.get_next_trial(ttl_seconds=1) - # pyre-fixme[16]: `Optional` has no attribute `status`. - self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running) + self.assertTrue(ax_client.experiment.trials[idx].status.is_running) time.sleep(1) # Wait for TTL to elapse. - self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running) + self.assertTrue(ax_client.experiment.trials[idx].status.is_running) ax_client.complete_trial(trial_index=idx, raw_data=(0, 0.0)) - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - self.assertEqual(ax_client.get_best_parameters()[0], params) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params) def test_fail_on_batch(self) -> None: ax_client = AxClient() @@ -2041,11 +2031,9 @@ def test_log_failure(self) -> None: ) _, idx = ax_client.get_next_trial() ax_client.log_trial_failure(idx, metadata={"dummy": "test"}) - # pyre-fixme[16]: `Optional` has no attribute `status`. - self.assertTrue(ax_client.experiment.trials.get(idx).status.is_failed) + self.assertTrue(ax_client.experiment.trials[idx].status.is_failed) self.assertEqual( - # pyre-fixme[16]: `Optional` has no attribute `run_metadata`. - ax_client.experiment.trials.get(idx).run_metadata.get("dummy"), + ax_client.experiment.trials[idx].run_metadata.get("dummy"), "test", ) with self.assertRaisesRegex(UnsupportedError, ".* no longer expects"): @@ -2063,8 +2051,7 @@ def test_attach_trial_and_get_trial_parameters(self) -> None: parameters={"x": 0.0, "y": 1.0}, arm_name=ARM_NAME ) ax_client.complete_trial(trial_index=idx, raw_data=5) - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - self.assertEqual(ax_client.get_best_parameters()[0], params) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params) self.assertEqual( ax_client.get_trial_parameters(trial_index=idx), {"x": 0, "y": 1} ) @@ -2089,13 +2076,11 @@ def test_attach_trial_ttl_seconds(self) -> None: params, idx = ax_client.attach_trial( parameters={"x": 0.0, "y": 1.0}, ttl_seconds=1 ) - # pyre-fixme[16]: `Optional` has no attribute `status`. - self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running) + self.assertTrue(ax_client.experiment.trials[idx].status.is_running) time.sleep(1) # Wait for TTL to elapse. - self.assertTrue(ax_client.experiment.trials.get(idx).status.is_running) + self.assertTrue(ax_client.experiment.trials[idx].status.is_running) ax_client.complete_trial(trial_index=idx, raw_data=5) - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - self.assertEqual(ax_client.get_best_parameters()[0], params) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params) self.assertEqual( ax_client.get_trial_parameters(trial_index=idx), {"x": 0, "y": 1} ) @@ -2110,8 +2095,7 @@ def test_attach_trial_numpy(self) -> None: ) params, idx = ax_client.attach_trial(parameters={"x": 0.0, "y": 1.0}) ax_client.complete_trial(trial_index=idx, raw_data=np.int32(5)) - # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. - self.assertEqual(ax_client.get_best_parameters()[0], params) + self.assertEqual(none_throws(ax_client.get_best_parameters())[0], params) def test_relative_oc_without_sq(self) -> None: """Must specify status quo to have relative outcome constraint.""" @@ -2304,8 +2288,7 @@ def test_overwrite(self) -> None: parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial( trial_index=trial_index, - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - raw_data=branin(*parameters.values()), + raw_data=float(branin(*parameters.values())), ) with self.assertRaises(ValueError): @@ -2335,8 +2318,7 @@ def test_overwrite(self) -> None: self.assertIn("x2", parameters.keys()) ax_client.complete_trial( trial_index=trial_index, - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - raw_data=branin(*parameters.values()), + raw_data=float(branin(*parameters.values())), ) def test_fixed_random_seed_reproducibility(self) -> None: @@ -2349,11 +2331,11 @@ def test_fixed_random_seed_reproducibility(self) -> None: ) for _ in range(5): params, idx = ax_client.get_next_trial() - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - ax_client.complete_trial(idx, branin(params.get("x"), params.get("y"))) + ax_client.complete_trial( + idx, float(branin(params.get("x"), params.get("y"))) + ) trial_parameters_1 = [ - # pyre-fixme[16]: `BaseTrial` has no attribute `arm`. - t.arm.parameters + none_throws(assert_is_instance(t, Trial).arm).parameters for t in ax_client.experiment.trials.values() ] ax_client = AxClient(random_seed=RANDOM_SEED) @@ -2365,10 +2347,12 @@ def test_fixed_random_seed_reproducibility(self) -> None: ) for _ in range(5): params, idx = ax_client.get_next_trial() - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - ax_client.complete_trial(idx, branin(params.get("x"), params.get("y"))) + ax_client.complete_trial( + idx, float(branin(params.get("x"), params.get("y"))) + ) trial_parameters_2 = [ - t.arm.parameters for t in ax_client.experiment.trials.values() + none_throws(assert_is_instance(t, Trial).arm).parameters + for t in ax_client.experiment.trials.values() ] self.assertEqual(trial_parameters_1, trial_parameters_2) @@ -2393,17 +2377,18 @@ def test_init_position_saved(self) -> None: with self.subTest(ax=ax_client, params=params, idx=idx): new_params, new_idx = ax_client.get_next_trial() # Sobol "init_position" setting should be saved on the generator run. + trial = assert_is_instance(ax_client.experiment.trials[idx], Trial) self.assertEqual( - # pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run`. - ax_client.experiment.trials[ - idx - ]._generator_run._generator_state_after_gen["init_position"], + none_throws( + none_throws(trial._generator_run)._generator_state_after_gen + )["init_position"], idx + 1, ) self.assertEqual(params, new_params) self.assertEqual(idx, new_idx) - # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, Union... - ax_client.complete_trial(idx, branin(params.get("x"), params.get("y"))) + ax_client.complete_trial( + idx, float(branin(params.get("x"), params.get("y"))) + ) def test_unnamed_experiment_snapshot(self) -> None: ax_client = AxClient(random_seed=RANDOM_SEED) @@ -2474,13 +2459,12 @@ def test_get_model_predictions_no_next_trial_parameterizations(self) -> None: ax_client = _set_up_client_for_get_model_predictions_no_next_trial() _attach_completed_trials(ax_client) - parameterizations = { + parameterizations: dict[int, TParameterization] = { 18: {"x1": 0.3, "x2": 0.5}, 19: {"x1": 0.4, "x2": 0.5}, 20: {"x1": 0.8, "x2": 0.5}, } parameterization_predictions_dict = ax_client.get_model_predictions( - # pyre-ignore [6] parameterizations=parameterizations ) # Expect predictions for only 3 input parameterizations, @@ -2491,14 +2475,13 @@ def test_get_model_predictions_for_parameterization_no_next_trial(self) -> None: ax_client = _set_up_client_for_get_model_predictions_no_next_trial() _attach_completed_trials(ax_client) - parameterizations = [ + parameterizations_list: list[TParameterization] = [ {"x1": 0.3, "x2": 0.5}, {"x1": 0.4, "x2": 0.5}, {"x1": 0.8, "x2": 0.5}, ] predictions_list = ax_client.get_model_predictions_for_parameterizations( - # pyre-ignore [6] - parameterizations=parameterizations + parameterizations=parameterizations_list ) self.assertEqual(len(predictions_list), 3) @@ -2636,8 +2619,10 @@ def helper_test_get_pareto_optimal_points( # NOTE: model predictions are very poor due to `mock_botorch_optimize`. # This overwrites the `predict` call to return the original observations, # while testing the rest of the code as if we're using predictions. - # pyre-fixme[16]: `Optional` has no attribute `model`. - model = ax_client.generation_strategy.adapter.generator + model = assert_is_instance( + none_throws(ax_client.generation_strategy.adapter).generator, + BoTorchGenerator, + ) ys = model.surrogate.training_data[0].Y with patch.object( model, "predict", return_value=(ys, torch.zeros(*ys.shape, ys.shape[-1])) @@ -2784,10 +2769,8 @@ def test_get_pareto_optimal_points_from_sobol_step_with_constraint_minimize_fals @mock_botorch_optimize def test_get_pareto_optimal_points_objective_threshold_inference( self, - # pyre-fixme[2]: Parameter must be annotated. - mock_observed_pareto, - # pyre-fixme[2]: Parameter must be annotated. - mock_predicted_pareto, + mock_observed_pareto: Mock, + mock_predicted_pareto: Mock, ) -> None: ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, include_objective_thresholds=False @@ -3029,8 +3012,7 @@ def test_should_stop_trials_early(self) -> None: ], support_intermediate_data=True, ) - # pyre-fixme[6]: For 1st param expected `Set[int]` but got `List[int]`. - actual = ax_client.should_stop_trials_early(trial_indices=[1, 2, 3]) + actual = ax_client.should_stop_trials_early(trial_indices={1, 2, 3}) self.assertEqual(actual, expected) def test_stop_trial_early(self) -> None: @@ -3076,9 +3058,10 @@ def test_max_concurrency_exception_when_early_stopping(self) -> None: support_intermediate_data=True, ) - exception = MaxParallelismReachedException(step_index=1, num_running=10) + exception: MaxParallelismReachedException = MaxParallelismReachedException( + step_index=1, num_running=10 + ) - # pyre-fixme[53]: Captured variable `exception` is not annotated. def fake_new_trial(*args: Any, **kwargs: Any) -> None: raise exception @@ -3172,16 +3155,19 @@ def test_SingleTaskGP_log_unordered_categorical_parameters(self) -> None: logs = [] ax_client = AxClient(random_seed=0) - params = [ - { - "name": f"x{i + 1}", - "type": "range", - "bounds": [*Branin._domain[i]], - "value_type": "float", - "log_scale": False, - } - for i in range(2) - ] + params = cast( + list[dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]]], + [ + { + "name": f"x{i + 1}", + "type": "range", + "bounds": [*Branin._domain[i]], + "value_type": "float", + "log_scale": False, + } + for i in range(2) + ], + ) with mock.patch( "ax.generation_strategy.dispatch_utils.logger.info", @@ -3189,11 +3175,6 @@ def test_SingleTaskGP_log_unordered_categorical_parameters(self) -> None: ): ax_client.create_experiment( name="branin_test_experiment", - # pyre-fixme[6]: for argument `parameters`, expected - # `List[Dict[str, Union[None, Dict[str, List[str]], - # Sequence[Union[None, bool, float, int, str]], - # bool, float, int, str]]]` - # but got `List[Dict[str, Union[List[int], bool, str]]]` parameters=params, objectives={"branin": ObjectiveProperties(minimize=True)}, ) @@ -3267,34 +3248,33 @@ def _set_up_client_for_get_model_predictions_no_next_trial() -> AxClient: return ax_client -# pyre-fixme[2]: Parameter must be annotated. -def _attach_completed_trials(ax_client) -> None: +def _attach_completed_trials(ax_client: AxClient) -> None: # Attach completed trials - trial1 = {"x1": 0.1, "x2": 0.1} + trial1: TParameterization = {"x1": 0.1, "x2": 0.1} parameters, trial_index = ax_client.attach_trial(trial1) ax_client.complete_trial( trial_index=trial_index, raw_data=_evaluate_test_metrics(parameters) ) - trial2 = {"x1": 0.2, "x2": 0.1} + trial2: TParameterization = {"x1": 0.2, "x2": 0.1} parameters, trial_index = ax_client.attach_trial(trial2) ax_client.complete_trial( trial_index=trial_index, raw_data=_evaluate_test_metrics(parameters) ) -# pyre-fixme[2]: Parameter must be annotated. -def _attach_not_completed_trials(ax_client) -> None: +def _attach_not_completed_trials(ax_client: AxClient) -> None: # Attach not yet completed trials - trial3 = {"x1": 0.3, "x2": 0.1} + trial3: TParameterization = {"x1": 0.3, "x2": 0.1} parameters, trial_index = ax_client.attach_trial(trial3) - trial4 = {"x1": 0.4, "x2": 0.1} + trial4: TParameterization = {"x1": 0.4, "x2": 0.1} parameters, trial_index = ax_client.attach_trial(trial4) # Test metric evaluation method -# pyre-fixme[2]: Parameter must be annotated. -def _evaluate_test_metrics(parameters) -> dict[str, tuple[float, float]]: +def _evaluate_test_metrics( + parameters: TParameterization, +) -> dict[str, tuple[float, float]]: x = np.array([parameters.get(f"x{i + 1}") for i in range(2)]) return {"test_metric1": (x[0] / x[1], 0.0), "test_metric2": (x[0] + x[1], 0.0)} diff --git a/ax/service/tests/test_global_stopping.py b/ax/service/tests/test_global_stopping.py index 18e4dc201c0..8e6703067d9 100644 --- a/ax/service/tests/test_global_stopping.py +++ b/ax/service/tests/test_global_stopping.py @@ -48,9 +48,7 @@ def get_ax_client_for_branin( def evaluate(self, parameters: TParameterization) -> dict[str, tuple[float, float]]: """Evaluates the parameters for branin experiment.""" x = np.array([parameters.get(f"x{i + 1}") for i in range(2)]) - # pyre-fixme[7]: Expected `Dict[str, Tuple[float, float]]` but got - # `Dict[str, Tuple[Union[float, ndarray], float]]`. - return {"branin": (branin(x), 0.0)} + return {"branin": (float(branin(x)), 0.0)} def test_global_stopping_integration(self) -> None: """ @@ -69,7 +67,6 @@ def test_global_stopping_integration(self) -> None: parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial( trial_index=trial_index, - # pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[... raw_data=self.evaluate(parameters), ) @@ -109,7 +106,6 @@ def test_min_trials(self) -> None: parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial( trial_index=trial_index, - # pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[... raw_data=self.evaluate(parameters), ) @@ -117,7 +113,6 @@ def test_min_trials(self) -> None: parameters, trial_index = ax_client.get_next_trial() ax_client.complete_trial( trial_index=trial_index, - # pyre-fixme[6]: For 2nd param expected `Union[Dict[str, Union[Tuple[Unio... raw_data=self.evaluate(parameters), ) self.assertIsNotNone(parameters) diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index b0eb544c61a..800be9ab2e5 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -6,6 +6,7 @@ # pyre-strict +from collections.abc import Sequence from typing import Any from ax.core.metric import Metric @@ -16,6 +17,7 @@ ParameterType, RangeParameter, ) +from ax.core.types import TParamValue from ax.runners.synthetic import SyntheticRunner from ax.service.utils.instantiation import InstantiationBase from ax.utils.common.testutils import TestCase @@ -332,7 +334,9 @@ def test_choice_with_is_sorted(self) -> None: _ = InstantiationBase.parameter_from_json(representation) def test_hss(self) -> None: - parameter_dicts = [ + parameter_dicts: list[ + dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]] + ] = [ { "name": "root", "type": "fixed", @@ -361,7 +365,6 @@ def test_hss(self) -> None: {"name": "another_int", "type": "fixed", "value": "2"}, ] search_space = InstantiationBase.make_search_space( - # pyre-fixme[6]: For 1st param expected `List[Dict[str, Union[None, Dict[... parameters=parameter_dicts, parameter_constraints=[], ) diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index a7ceba4af73..b417fd58426 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -9,13 +9,15 @@ import functools import time +from collections.abc import Callable, Sequence from logging import WARN from queue import Queue from threading import Event, Lock +from typing import cast import numpy as np from ax.adapter.registry import Generators -from ax.core.types import TEvaluationOutcome, TParameterization +from ax.core.types import TEvaluationOutcome, TParameterization, TParamValue from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -41,10 +43,9 @@ def setUp(self) -> None: ] ) self.ax_client = AxClient(generation_strategy=generation_strategy) - self.ax_client.create_experiment( - name="hartmann_test_experiment", - # pyre-fixme[6] - parameters=[ + parameters = cast( + list[dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]]], + [ { "name": f"x{i}", "type": "range", @@ -52,6 +53,10 @@ def setUp(self) -> None: } for i in range(1, 7) ], + ) + self.ax_client.create_experiment( + name="hartmann_test_experiment", + parameters=parameters, objectives={"hartmann6": ObjectiveProperties(minimize=True)}, tracking_metric_names=["l2norm"], ) @@ -76,8 +81,10 @@ def test_interactive_loop(self) -> None: ax_client=self.ax_client, num_trials=15, candidate_queue_maxsize=3, - # pyre-fixme[6] - elicitation_function=self._elicit, + elicitation_function=cast( + Callable[[tuple[TParameterization, int]], TEvaluationOutcome], + self._elicit, + ), ) self.assertTrue(optimization_completed) @@ -94,8 +101,10 @@ def _aborted_elicit( ax_client=self.ax_client, num_trials=15, candidate_queue_maxsize=3, - # pyre-fixme[6] - elicitation_function=_aborted_elicit, + elicitation_function=cast( + Callable[[tuple[TParameterization, int]], TEvaluationOutcome], + _aborted_elicit, + ), ) self.assertFalse(optimization_completed) @@ -144,8 +153,10 @@ def _sleep_elicit( ax_client=self.ax_client, num_trials=3, candidate_queue_maxsize=3, - # pyre-fixme[6] - elicitation_function=_sleep_elicit, + elicitation_function=cast( + Callable[[tuple[TParameterization, int]], TEvaluationOutcome], + _sleep_elicit, + ), ) # Assert sleep and retry warning is somewhere in the logs diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index 48f0d485685..a403a0c684d 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -6,63 +6,63 @@ # pyre-strict +from typing import Any from unittest.mock import Mock, patch import numpy as np -import numpy.typing as npt from ax.adapter.registry import Generators +from ax.core.types import TParameterization from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, ) +from ax.generators.random.sobol import SobolGenerator from ax.metrics.branin import branin from ax.service.managed_loop import OptimizationLoop, optimize from ax.utils.common.testutils import TestCase from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import assert_is_instance, none_throws def _branin_evaluation_function( - # pyre-fixme[2]: Parameter must be annotated. - parameterization, - weight=None, # pyre-fixme[2]: Parameter must be annotated. -) -> dict[str, tuple[float | npt.NDArray, float]]: + parameterization: TParameterization, + weight: float | None = None, +) -> dict[str, tuple[float, float]]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") - x1, x2 = parameterization["x1"], parameterization["x2"] + x1, x2 = float(parameterization["x1"]), float(parameterization["x2"]) return { - "branin": (branin(x1, x2), 0.0), - "constrained_metric": (-branin(x1, x2), 0.0), + "branin": (float(branin(x1, x2)), 0.0), + "constrained_metric": (float(-branin(x1, x2)), 0.0), } def _branin_evaluation_function_v2( - # pyre-fixme[2]: Parameter must be annotated. - parameterization, - weight=None, # pyre-fixme[2]: Parameter must be annotated. -) -> tuple[float | npt.NDArray, float]: + parameterization: TParameterization, + weight: float | None = None, +) -> tuple[float, float]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") - x1, x2 = parameterization["x1"], parameterization["x2"] - return (branin(x1, x2), 0.0) + x1, x2 = float(parameterization["x1"]), float(parameterization["x2"]) + return (float(branin(x1, x2)), 0.0) def _branin_evaluation_function_with_unknown_sem( - # pyre-fixme[2]: Parameter must be annotated. - parameterization, - weight=None, # pyre-fixme[2]: Parameter must be annotated. -) -> tuple[float | npt.NDArray, None]: + parameterization: TParameterization, + weight: float | None = None, +) -> tuple[float, None]: if any(param_name not in parameterization.keys() for param_name in ["x1", "x2"]): raise ValueError("Parametrization does not contain x1 or x2") - x1, x2 = parameterization["x1"], parameterization["x2"] - return (branin(x1, x2), None) + x1, x2 = float(parameterization["x1"]), float(parameterization["x2"]) + return (float(branin(x1, x2)), None) class TestManagedLoop(TestCase): """Check functionality of optimization loop.""" def test_with_evaluation_function_propagates_parameter_constraints(self) -> None: - kwargs = { + kwargs: dict[str, Any] = { "parameters": [ { "name": "x1", @@ -151,9 +151,7 @@ def test_branin_with_active_parameter_constraints(self) -> None: bp, _ = loop.full_run().get_best_point() self.assertIn("x1", bp) self.assertIn("x2", bp) - # pyre-fixme[58]: `+` is not supported for operand types `Union[None, bool, - # float, int, str]` and `Union[None, bool, float, int, str]`. - self.assertLessEqual(bp["x1"] + bp["x2"], 1.0 + 1e-8) + self.assertLessEqual(float(bp["x1"]) + float(bp["x2"]), 1.0 + 1e-8) with self.assertRaisesRegex(ValueError, "Optimization is complete"): loop.run_trial() @@ -241,11 +239,8 @@ def test_branin_batch(self) -> None: self.assertIn("x2", bp) assert vals is not None self.assertIn("branin", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("branin", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("branin", vals[1]["branin"]) + self.assertIn("branin", none_throws(vals[1])) + self.assertIn("branin", none_throws(vals[1])["branin"]) # Check that all total_trials * arms_per_trial * 2 metrics evaluations # are present in the dataframe. self.assertEqual(len(loop.experiment.fetch_data().df.index), 12) @@ -270,11 +265,8 @@ def test_optimize(self) -> None: self.assertIn("x2", best) assert vals is not None self.assertIn("objective", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("objective", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("objective", vals[1]["objective"]) + self.assertIn("objective", none_throws(vals[1])) + self.assertIn("objective", none_throws(vals[1])["objective"]) @patch( "ax.service.managed_loop." @@ -301,11 +293,8 @@ def test_optimize_with_predictions(self, _) -> None: self.assertIn("x2", best) assert vals is not None self.assertIn("a", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("a", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("a", vals[1]["a"]) + self.assertIn("a", none_throws(vals[1])) + self.assertIn("a", none_throws(vals[1])["a"]) @mock_botorch_optimize def test_optimize_unknown_sem(self) -> None: @@ -327,11 +316,8 @@ def test_optimize_unknown_sem(self) -> None: self.assertIn("x2", best) self.assertIsNotNone(vals) self.assertIn("objective", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("objective", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("objective", vals[1]["objective"]) + self.assertIn("objective", none_throws(vals[1])) + self.assertIn("objective", none_throws(vals[1])["objective"]) def test_optimize_propagates_random_seed(self) -> None: """Tests optimization as a single call.""" @@ -347,8 +333,8 @@ def test_optimize_propagates_random_seed(self) -> None: total_trials=5, random_seed=12345, ) - # pyre-fixme[16]: Optional type has no attribute `model`. - self.assertEqual(12345, model.generator.seed) + generator = assert_is_instance(none_throws(model).generator, SobolGenerator) + self.assertEqual(12345, generator.seed) def test_optimize_search_space_exhausted(self) -> None: """Tests optimization as a single call.""" @@ -370,11 +356,8 @@ def test_optimize_search_space_exhausted(self) -> None: self.assertIn("x2", best) self.assertIsNotNone(vals) self.assertIn("objective", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("objective", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("objective", vals[1]["objective"]) + self.assertIn("objective", none_throws(vals[1])) + self.assertIn("objective", none_throws(vals[1])["objective"]) def test_custom_gs(self) -> None: """Managed loop with custom generation strategy""" @@ -432,18 +415,14 @@ def test_optimize_graceful_exit_on_exception(self) -> None: self.assertIn("x2", best) self.assertIsNotNone(vals) self.assertIn("objective", vals[0]) - # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any], - # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`. - self.assertIn("objective", vals[1]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. - self.assertIn("objective", vals[1]["objective"]) + self.assertIn("objective", none_throws(vals[1])) + self.assertIn("objective", none_throws(vals[1])["objective"]) @patch( "ax.core.experiment.Experiment.new_trial", side_effect=RuntimeError("cholesky_cpu error - bad matrix"), ) - # pyre-fixme[3]: Return type must be annotated. - def test_annotate_exception(self, _): + def test_annotate_exception(self, _: Mock) -> None: strategy0 = GenerationStrategy( name="Sobol", steps=[GenerationStep(generator=Generators.SOBOL, num_trials=-1)], diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index e857215be50..4eab606a5ef 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -9,7 +9,9 @@ import itertools import logging from collections import namedtuple +from collections.abc import Callable from logging import DEBUG, INFO, WARN +from typing import Any from unittest import mock from unittest.mock import patch @@ -265,8 +267,7 @@ def test_exp_to_df(self) -> None: with patch.object(Experiment, "lookup_data", lambda self: mock_results): df = exp_to_df(exp=exp) # all but two rows should have a metric value of NaN - # pyre-fixme[16]: `bool` has no attribute `sum`. - self.assertEqual(pd.isna(df[OBJECTIVE_NAME]).sum(), len(df.index) - 2) + self.assertEqual(df[OBJECTIVE_NAME].isna().sum(), len(df.index) - 2) # an experiment with more results than arms raises an error with ( @@ -369,16 +370,16 @@ def test_get_standard_plots(self) -> None: self.assertTrue(all(isinstance(plot, go.Figure) for plot in plots)) # Raise an exception in one plot and make sure we generate the others - for plot_function, num_expected_plots in [ - [_get_curve_plot_dropdown, 8], # Not used - [_get_objective_trace_plot, 6], - [_objective_vs_true_objective_scatter, 7], - [_get_objective_v_param_plots, 6], - [_get_cross_validation_plots, 7], - [plot_feature_importance_by_feature_plotly, 6], - ]: + plot_test_cases: list[tuple[Callable[..., Any], int]] = [ + (_get_curve_plot_dropdown, 8), # Not used + (_get_objective_trace_plot, 6), + (_objective_vs_true_objective_scatter, 7), + (_get_objective_v_param_plots, 6), + (_get_cross_validation_plots, 7), + (plot_feature_importance_by_feature_plotly, 6), + ] + for plot_function, num_expected_plots in plot_test_cases: with mock.patch( - # pyre-ignore f"ax.service.utils.report_utils.{plot_function.__name__}", side_effect=Exception(), ): diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 2f2d8d82bda..80b120825f8 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -11,8 +11,10 @@ import os import tempfile from collections import OrderedDict +from collections.abc import Callable from functools import partial from math import nan +from typing import Any import numpy as np import pandas as pd @@ -51,6 +53,11 @@ from ax.generation_strategy.center_generation_node import CenterGenerationNode from ax.generation_strategy.generation_node import GenerationNode, GenerationStep from ax.generation_strategy.generator_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( + MaxGenerationParallelism, + MaxTrialsAwaitingData, + MinTrials, +) from ax.generators.torch.botorch_modular.kernels import ScaleMaternKernel from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.generators.torch.botorch_modular.utils import ModelConfig @@ -181,11 +188,10 @@ from botorch.models.transforms.outcome import Standardize from botorch.sampling.normal import SobolQMCNormalSampler from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws -# pyre-fixme[5]: Global expression must be annotated. -TEST_CASES = [ +TEST_CASES: list[tuple[str, Callable[..., Any]]] = [ ("AbandonedArm", get_abandoned_arm), ( "AdditiveMapSaasSingleTaskGP", @@ -1280,9 +1286,9 @@ def test_block_gen_if_met_migration(self) -> None: node = generation_node_from_json(json) self.assertEqual(len(node.transition_criteria), 0) self.assertEqual(len(node.pausing_criteria), 1) - blocking = node.pausing_criteria[0] - self.assertEqual(blocking.__class__.__name__, "MaxGenerationParallelism") - # pyre-ignore[16]: Attribute exists on MaxGenerationParallelism + blocking = assert_is_instance( + node.pausing_criteria[0], MaxGenerationParallelism + ) self.assertEqual(blocking.threshold, 3) with self.subTest("MinTrials_with_block_gen_if_met_only"): @@ -1319,8 +1325,9 @@ def test_block_gen_if_met_migration(self) -> None: node = generation_node_from_json(json) self.assertEqual(len(node.transition_criteria), 0) self.assertEqual(len(node.pausing_criteria), 1) - blocking = node.pausing_criteria[0] - self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") + blocking = assert_is_instance( + node.pausing_criteria[0], MaxTrialsAwaitingData + ) self.assertEqual(blocking.threshold, 5) with self.subTest("MinTrials_with_block_gen_if_met_and_block_transition"): @@ -1357,16 +1364,13 @@ def test_block_gen_if_met_migration(self) -> None: # Should have both self.assertEqual(len(node.transition_criteria), 1) self.assertEqual(len(node.pausing_criteria), 1) - tc = node.transition_criteria[0] - self.assertEqual(tc.__class__.__name__, "MinTrials") - # pyre-ignore[16]: Attribute exists on MinTrials + tc = assert_is_instance(node.transition_criteria[0], MinTrials) self.assertEqual(tc.threshold, 5) self.assertEqual(tc.transition_to, "next_node") - blocking = node.pausing_criteria[0] - self.assertEqual(blocking.__class__.__name__, "MaxTrialsAwaitingData") - # pyre-ignore[16]: threshold exists on MaxTrialsAwaitingData + blocking = assert_is_instance( + node.pausing_criteria[0], MaxTrialsAwaitingData + ) self.assertEqual(blocking.threshold, 5) - # pyre-ignore[16]: use_all_trials_in_exp exists on MaxTrialsAwaitingData self.assertTrue(blocking.use_all_trials_in_exp) def test_SobolQMCNormalSampler(self) -> None: diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index b3690d58b46..83878d96a0e 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -35,6 +35,7 @@ from ax.core.experiment_status import ExperimentStatus from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric +from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -209,9 +210,8 @@ def test_connection_to_db_with_url(self) -> None: init_engine_and_session_factory(url="sqlite://", force_init=True) def MockDBAPI(self) -> MagicMock: - connection = Mock() + connection: Mock = Mock() - # pyre-fixme[53]: Captured variable `connection` is not annotated. def connect(*args: Any, **kwargs: Any) -> Mock: return connection @@ -447,10 +447,10 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None: tracking_metrics=[Metric(name="tracking")], is_test=True, auxiliary_experiments_by_purpose={ - # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute - self.config.auxiliary_experiment_purpose_enum.PE_EXPERIMENT: [ - AuxiliaryExperiment(experiment=aux_experiment) - ] + cast( + type[AuxiliaryExperimentPurpose], + self.config.auxiliary_experiment_purpose_enum, + ).PE_EXPERIMENT: [AuxiliaryExperiment(experiment=aux_experiment)] }, ) self.assertIsNone(experiment_w_aux_exp.db_id) @@ -474,8 +474,10 @@ def test_saving_and_loading_experiment_with_aux_exp_reduced_state(self) -> None: aux_exp_gs = get_generation_strategy() aux_exp.new_trial(aux_exp_gs.gen_single_trial(experiment=aux_exp)) save_experiment(aux_exp, config=self.config) - # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute - purpose = self.config.auxiliary_experiment_purpose_enum.PE_EXPERIMENT + purpose = cast( + type[AuxiliaryExperimentPurpose], + self.config.auxiliary_experiment_purpose_enum, + ).PE_EXPERIMENT target_exp = Experiment( name="test_experiment_w_aux_exp_in_SQAStoreTest_reduced_state", @@ -531,10 +533,10 @@ def test_saving_with_aux_exp_not_in_db(self) -> None: search_space=get_search_space(), is_test=True, auxiliary_experiments_by_purpose={ - # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute - self.config.auxiliary_experiment_purpose_enum.PE_EXPERIMENT: [ - AuxiliaryExperiment(experiment=aux_experiment) - ] + cast( + type[AuxiliaryExperimentPurpose], + self.config.auxiliary_experiment_purpose_enum, + ).PE_EXPERIMENT: [AuxiliaryExperiment(experiment=aux_experiment)] }, ) with self.assertRaisesRegex(SQAEncodeError, "that does not exist in"): @@ -545,8 +547,10 @@ def test_saving_and_loading_experiment_with_cross_referencing_aux_exp( ) -> None: exp1_name = "test_aux_exp_in_SQAStoreTest1" exp2_name = "test_aux_exp_in_SQAStoreTest2" - # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute - exp_purpose = self.config.auxiliary_experiment_purpose_enum.PE_EXPERIMENT + exp_purpose = cast( + type[AuxiliaryExperimentPurpose], + self.config.auxiliary_experiment_purpose_enum, + ).PE_EXPERIMENT exp1 = Experiment( name=exp1_name, @@ -796,7 +800,10 @@ def test_load_experiment_skip_metrics_and_runners(self) -> None: side_effect=Decoder(SQAConfig()).experiment_from_sqa, ) def test_experiment_save_and_load_reduced_state( - self, _mock_exp_from_sqa, _mock_trial_from_sqa, _mock_gr_from_sqa + self, + _mock_exp_from_sqa: Mock, + _mock_trial_from_sqa: Mock, + _mock_gr_from_sqa: Mock, ) -> None: for skip_runners_and_metrics in [False, True]: # 1. No abandoned arms + no trials case, reduced state should be the @@ -924,29 +931,28 @@ def wrapper(*args: Any, **kwargs: Any) -> T: def test_mt_experiment_save_and_load(self) -> None: experiment = get_multi_type_experiment(add_trials=True) save_experiment(experiment) - loaded_experiment = load_experiment(experiment.name) + loaded_experiment = assert_is_instance( + load_experiment(experiment.name), MultiTypeExperiment + ) self.assertEqual(loaded_experiment.default_trial_type, "type1") self.assertEqual(len(loaded_experiment._trial_type_to_runner), 2) - # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) def test_mt_experiment_save_and_load_skip_runners_and_metrics(self) -> None: experiment = get_multi_type_experiment(add_trials=True) save_experiment(experiment) - loaded_experiment = load_experiment( - experiment.name, skip_runners_and_metrics=True + loaded_experiment = assert_is_instance( + load_experiment(experiment.name, skip_runners_and_metrics=True), + MultiTypeExperiment, ) self.assertEqual(loaded_experiment.default_trial_type, "type1") self.assertIsNone(loaded_experiment._trial_type_to_runner["type1"]) self.assertIsNone(loaded_experiment._trial_type_to_runner["type2"]) - # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) @@ -1640,10 +1646,7 @@ def test_experiment_objective_threshold_updates(self) -> None: experiment.optimization_config = optimization_config save_experiment(experiment) self.assertEqual(get_session().query(SQAMetric).count(), 7) - self.assertIsNotNone( - # pyre-fixme[16]: Optional type has no attribute `objective_thresholds`. - experiment.optimization_config.objective_thresholds[0].metric.db_id - ) + self.assertIsNotNone(optimization_config.objective_thresholds[0].metric.db_id) # add outcome constraint outcome_constraint2 = OutcomeConstraint( @@ -2197,8 +2200,7 @@ def test_encode_decode_generation_strategy_base_case(self) -> None: save_generation_strategy(generation_strategy=generation_strategy) # Also try restoring this generation strategy by its ID in the DB. new_generation_strategy = load_generation_strategy_by_id( - # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. - gs_id=generation_strategy._db_id + gs_id=none_throws(generation_strategy._db_id) ) # Some fields of the reloaded GS are not expected to be set (both will be # set during next model fitting call), so we unset them on the original GS as @@ -2255,8 +2257,7 @@ def test_encode_decode_generation_node_gs_with_advanced_settings(self) -> None: # Try restoring this generation strategy by its ID in the DB. save_generation_strategy(generation_strategy=generation_strategy) new_generation_strategy = load_generation_strategy_by_id( - # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. - gs_id=generation_strategy._db_id + gs_id=none_throws(generation_strategy._db_id) ) # Some fields of the reloaded GS are not expected to be set (both will be @@ -2701,13 +2702,12 @@ def test_update_runner(self) -> None: encoder=self.encoder, decoder=self.decoder, ) - # pyre-fixme[16]: Optional type has no attribute `db_id`. - self.assertIsNone(experiment.runner.db_id) - self.assertIsNotNone(experiment.runner) - # pyre-fixme[16]: `Runner` has no attribute `dummy_metadata`. - self.assertIsNone(experiment.runner.dummy_metadata) + runner = none_throws(experiment.runner) + self.assertIsNone(runner.db_id) + self.assertIsNotNone(runner) + self.assertIsNone(assert_is_instance(runner, SyntheticRunner).dummy_metadata) save_experiment(experiment=experiment) - old_runner_db_id = experiment.runner.db_id + old_runner_db_id = runner.db_id self.assertIsNotNone(old_runner_db_id) new_runner = get_synthetic_runner() # pyre-fixme[8]: Attribute has type `Optional[str]`; used as `Dict[str, str]`. @@ -2721,9 +2721,9 @@ def test_update_runner(self) -> None: decoder=self.decoder, ) self.assertIsNotNone(new_runner.db_id) # New runner should be added to DB. - self.assertEqual(experiment.runner.db_id, new_runner.db_id) + self.assertEqual(none_throws(experiment.runner).db_id, new_runner.db_id) loaded_experiment = load_experiment(experiment_name=experiment.name) - self.assertEqual(loaded_experiment.runner.db_id, new_runner.db_id) + self.assertEqual(none_throws(loaded_experiment.runner).db_id, new_runner.db_id) def test_experiment_validation(self) -> None: exp = get_experiment() @@ -2810,9 +2810,9 @@ def test_get_immutable_search_space_and_opt_config(self) -> None: ) def test_immutable_search_space_and_opt_config_loading( self, - _mock_get_exp_sqa_imm_oc_ss, - _mock_get_gs_sqa_imm_oc_ss, - _mock_gr_from_sqa, + _mock_get_exp_sqa_imm_oc_ss: Mock, + _mock_get_gs_sqa_imm_oc_ss: Mock, + _mock_gr_from_sqa: Mock, ) -> None: experiment = get_experiment_with_batch_trial(constrain_search_space=False) experiment._properties = {Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True} @@ -2932,8 +2932,7 @@ def test_generator_run_validated_fields(self) -> None: # loaded are non-null. save_experiment(exp) loaded_exp = load_experiment(exp.name) - # pyre-fixme[16]: Optional type has no attribute `generator_run`. - loaded_gr = loaded_exp.trials.get(0).generator_run + loaded_gr = assert_is_instance(loaded_exp.trials[0], Trial).generator_run for instrumented_attr in GR_LARGE_MODEL_ATTRS: python_attr = SQA_COL_TO_GR_ATTR[instrumented_attr.key] self.assertIsNotNone(getattr(loaded_gr, f"_{python_attr}")) @@ -2948,7 +2947,9 @@ def test_generator_run_validated_fields(self) -> None: # was not propagated to the DB. save_experiment(loaded_exp) newly_loaded_exp = load_experiment(exp.name) - newly_loaded_gr = newly_loaded_exp.trials.get(0).generator_run + newly_loaded_gr = assert_is_instance( + newly_loaded_exp.trials[0], Trial + ).generator_run for instrumented_attr in GR_LARGE_MODEL_ATTRS: python_attr = SQA_COL_TO_GR_ATTR[instrumented_attr.key] self.assertIsNotNone(getattr(newly_loaded_gr, f"_{python_attr}")) diff --git a/ax/storage/sqa_store/tests/test_with_db_settings_base.py b/ax/storage/sqa_store/tests/test_with_db_settings_base.py index 1a460c951d7..52f85f0c0e7 100644 --- a/ax/storage/sqa_store/tests/test_with_db_settings_base.py +++ b/ax/storage/sqa_store/tests/test_with_db_settings_base.py @@ -10,7 +10,9 @@ import string from unittest.mock import patch +from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment +from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.sqa_store.db import init_test_engine_and_session_factory @@ -32,6 +34,7 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import DEFAULT_USER, get_experiment, get_generator_run from ax.utils.testing.modeling_stubs import get_generation_strategy +from pyre_extensions import assert_is_instance class TestWithDBSettingsBase(TestCase): @@ -260,8 +263,7 @@ def test_updated_trials_mini_batch(self) -> None: experiment.name, decoder=self.with_db_settings.db_settings.decoder ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `status`. - loaded_experiment.trials.get(trial.index).status, + loaded_experiment.trials[trial.index].status, TrialStatus.CANDIDATE, ) self.assertIsNotNone(trial.db_id) @@ -294,17 +296,15 @@ def test_update_reduced_state_generator_runs(self) -> None: save_generation_strategy=True ) - trials = [experiment.new_trial() for _ in range(5)] + trials: list[BaseTrial] = [experiment.new_trial() for _ in range(5)] grs = [] for t in trials: gr = generation_strategy.gen_single_trial(experiment) grs.append(gr) - t.add_generator_run(gr) + assert_is_instance(t, Trial).add_generator_run(gr) self.with_db_settings._save_or_update_trials_and_generation_strategy_if_possible( # noqa E501 experiment=experiment, - # pyre-fixme[6]: For 2nd param expected `List[BaseTrial]` but got - # `List[Trial]`. trials=trials, generation_strategy=generation_strategy, new_generator_runs=grs, @@ -320,11 +320,11 @@ def test_update_reduced_state_generator_runs(self) -> None: for attr in GR_LARGE_MODEL_ATTRS: # Map SQA column name to Python attribute name python_attr_name = f"_{SQA_COL_TO_GR_ATTR[attr.key]}" + t = assert_is_instance(trial, Trial) if idx < len(loaded_experiment.trials) - 1: - # pyre-fixme[16]: `BaseTrial` has no attribute `generator_run`. - self.assertIsNone(getattr(trial.generator_run, python_attr_name)) + self.assertIsNone(getattr(t.generator_run, python_attr_name)) else: - self.assertIsNotNone(getattr(trial.generator_run, python_attr_name)) + self.assertIsNotNone(getattr(t.generator_run, python_attr_name)) loaded_generation_strategy = _load_generation_strategy_by_experiment_name( experiment.name, decoder=self.with_db_settings.db_settings.decoder diff --git a/ax/utils/common/tests/test_docutils.py b/ax/utils/common/tests/test_docutils.py index c86e3aadebc..b7cc8fb4756 100644 --- a/ax/utils/common/tests/test_docutils.py +++ b/ax/utils/common/tests/test_docutils.py @@ -21,8 +21,7 @@ def has_no_doc() -> None: class TestDocUtils(TestCase): def test_transfer_doc(self) -> None: @copy_doc(has_doc) - # pyre-fixme[3]: Return type must be annotated. - def inherits_doc(): + def inherits_doc() -> None: pass self.assertEqual(inherits_doc.__doc__, "I have a docstring") @@ -31,8 +30,7 @@ def test_fail_when_already_has_doc(self) -> None: with self.assertRaises(ValueError): @copy_doc(has_doc) - # pyre-fixme[3]: Return type must be annotated. - def inherits_doc(): + def inherits_doc() -> None: """I already have a doc string""" pass @@ -40,6 +38,5 @@ def test_fail_when_no_doc_to_copy(self) -> None: with self.assertRaises(ValueError): @copy_doc(has_no_doc) - # pyre-fixme[3]: Return type must be annotated. - def f(): + def f() -> None: pass diff --git a/ax/utils/common/tests/test_equality.py b/ax/utils/common/tests/test_equality.py index 12cba4d3dd2..18e23c1d741 100644 --- a/ax/utils/common/tests/test_equality.py +++ b/ax/utils/common/tests/test_equality.py @@ -24,9 +24,7 @@ class EqualityTest(TestCase): def test_EqualityTypechecker(self) -> None: @equality_typechecker - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def eq(x, y): + def eq(x: object, y: object) -> bool: return x == y self.assertFalse(eq(5, 5.0)) diff --git a/ax/utils/common/tests/test_executils.py b/ax/utils/common/tests/test_executils.py index 0ea878c3eea..738490ddd5a 100644 --- a/ax/utils/common/tests/test_executils.py +++ b/ax/utils/common/tests/test_executils.py @@ -262,12 +262,10 @@ def test_on_function_with_wrapper_message(self) -> None: instance methods. """ - mock = Mock() + mock: Mock = Mock() @retry_on_exception(wrap_error_message_in="Wrapper error message") - # pyre-fixme[53]: Captured variable `mock` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - def error_throwing_function(): + def error_throwing_function() -> None: mock() raise RuntimeError("I failed") diff --git a/ax/utils/common/tests/test_testutils.py b/ax/utils/common/tests/test_testutils.py index 5c728fdb71d..e20265f6e5d 100644 --- a/ax/utils/common/tests/test_testutils.py +++ b/ax/utils/common/tests/test_testutils.py @@ -17,13 +17,12 @@ from botorch.models.gp_regression import SingleTaskGP -# pyre-fixme[3]: Return type must be annotated. -def _f(): +def _f() -> None: e = RuntimeError("Test") raise e -F_FAILURE_LINENO = 23 # Line # for the error in `_f`. +F_FAILURE_LINENO = 22 # Line # for the error in `_f`. def _g() -> None: diff --git a/ax/utils/testing/tests/test_backend_simulator.py b/ax/utils/testing/tests/test_backend_simulator.py index f393d9f4e3f..9b56b9233d7 100644 --- a/ax/utils/testing/tests/test_backend_simulator.py +++ b/ax/utils/testing/tests/test_backend_simulator.py @@ -13,6 +13,7 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.backend_simulator import BackendSimulator, BackendSimulatorOptions from ax.utils.testing.utils_testing_stubs import get_backend_simulator_with_trials +from pyre_extensions import none_throws class BackendSimulatorTest(TestCase): @@ -123,8 +124,7 @@ def test_backend_simulator_internal_clock(self) -> None: sim.lookup_trial_index_status(trial_index=2), TrialStatus.COMPLETED ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `sim_completed_time`. - sim.get_sim_trial_by_index(trial_index=2).sim_completed_time, + none_throws(sim.get_sim_trial_by_index(trial_index=2)).sim_completed_time, 2.0, ) with self.assertRaisesRegex(ValueError, "Trial 100 not found in simulator"): diff --git a/ax/utils/testing/tests/test_mock.py b/ax/utils/testing/tests/test_mock.py index a2d8471de6f..cde1630037f 100644 --- a/ax/utils/testing/tests/test_mock.py +++ b/ax/utils/testing/tests/test_mock.py @@ -40,7 +40,7 @@ def test_botorch_mocks(self) -> None: with mock_botorch_optimize_context_manager(): gen_candidates_scipy( initial_conditions=torch.tensor([[0.0]]), - acquisition_function=MockAcquisitionFunction(), # pyre-ignore [6] + acquisition_function=MockAcquisitionFunction(), # pyre-ignore[6] ) def test_fully_bayesian_mocks(self) -> None: