diff --git a/ax/adapter/tests/test_base_adapter.py b/ax/adapter/tests/test_base_adapter.py index 97716a13993..6e4e5eed757 100644 --- a/ax/adapter/tests/test_base_adapter.py +++ b/ax/adapter/tests/test_base_adapter.py @@ -76,7 +76,7 @@ from botorch.exceptions.warnings import InputDataWarning from botorch.models.utils.assorted import validate_input_scaling from pandas.testing import assert_frame_equal -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws ADAPTER__GEN_PATH: str = "ax.adapter.base.Adapter._gen" @@ -908,8 +908,14 @@ def test_set_model_space(self) -> None: .index.get_level_values("arm_name") ) self.assertEqual(set(ood_arms), {"status_quo", "custom"}) - self.assertEqual(m.model_space.parameters["x1"].lower, -5.0) # pyre-ignore[16] - self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16] + self.assertEqual( + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, + -5.0, + ) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 15.0, + ) self.assertEqual(len(m.model_space.parameter_constraints), 1) # With expand model space, custom is not OOD, and model space is expanded @@ -925,8 +931,14 @@ def test_set_model_space(self) -> None: .index.get_level_values("arm_name") ) self.assertEqual(set(ood_arms), {"status_quo"}) - self.assertEqual(m.model_space.parameters["x1"].lower, -20.0) - self.assertEqual(m.model_space.parameters["x2"].upper, 18.0) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, + -20.0, + ) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 18.0, + ) self.assertEqual(m.model_space.parameter_constraints, []) # With fill values, SQ is also in design, and x2 is further expanded @@ -941,7 +953,10 @@ def test_set_model_space(self) -> None: transform_configs={"FillMissingParameters": {"fill_values": sq_vals}}, ) self.assertEqual(sum(m.training_in_design), 7) - self.assertEqual(m.model_space.parameters["x2"].upper, 20) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 20, + ) self.assertEqual(m.model_space.parameter_constraints, []) # Using parameter backfill values @@ -955,7 +970,10 @@ def test_set_model_space(self) -> None: search_space=ss, ) self.assertEqual(sum(m.training_in_design), 7) - self.assertEqual(m.model_space.parameters["x2"].upper, 20) + self.assertEqual( + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, + 20, + ) self.assertEqual(m.model_space.parameter_constraints, []) # Check log scale expansion with OOD trial having parameter value == 0 @@ -992,12 +1010,12 @@ def test_set_model_space(self) -> None: # Assert that the expanded model space did not include 0.0 self.assertEqual( - m.model_space.parameters["x1"].lower, + assert_is_instance(m.model_space.parameters["x1"], RangeParameter).lower, 0.0001, ) # x2 model space should still be expanded self.assertEqual( - m.model_space.parameters["x2"].upper, + assert_is_instance(m.model_space.parameters["x2"], RangeParameter).upper, 2.0, ) diff --git a/ax/adapter/tests/test_cross_validation.py b/ax/adapter/tests/test_cross_validation.py index 55908746e1d..5001bdbf608 100644 --- a/ax/adapter/tests/test_cross_validation.py +++ b/ax/adapter/tests/test_cross_validation.py @@ -9,6 +9,7 @@ import warnings from collections.abc import Iterable from itertools import product +from typing import cast from unittest import mock import numpy as np @@ -70,6 +71,7 @@ from gpytorch.distributions import MultivariateNormal from linear_operator.operators import DiagLinearOperator from pandas import DataFrame +from pyre_extensions import assert_is_instance # Number of in-design points created by _create_adapter_with_out_of_design_points() _OOD_ADAPTER_IN_DESIGN_COUNT = 3 @@ -78,9 +80,8 @@ class CrossValidationTest(TestCase): def setUp(self) -> None: super().setUp() - # pyre-ignore [9] Pyre is too picky with union types. parameterizations: list[TParameterization] = [ - {"x": x} for x in [2.0, 2.0, 3.0, 4.0] + cast(TParameterization, {"x": x}) for x in [2.0, 2.0, 3.0, 4.0] ] means = [[2.0, 4.0], [3.0, 5.0], [7.0, 8.0], [9.0, 10.0]] sems = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]] @@ -894,29 +895,27 @@ def test_efficient_loo_cv_with_fully_bayesian_model(self) -> None: experiment = get_branin_experiment(with_batch=True, with_completed_batch=True) # Create adapter with SaasFullyBayesianSingleTaskGP + generator = BoTorchGenerator( + surrogate=Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ + ModelConfig( + botorch_model_class=SaasFullyBayesianSingleTaskGP, + ) + ], + ), + ) + ) adapter = TorchAdapter( experiment=experiment, - generator=BoTorchGenerator( - surrogate=Surrogate( - surrogate_spec=SurrogateSpec( - model_configs=[ - ModelConfig( - botorch_model_class=SaasFullyBayesianSingleTaskGP, - ) - ], - ), - ) - ), + generator=generator, transforms=[UnitX], ) # We need to mock the MCMC fitting to avoid running actual NUTS sampling # which is very slow. Instead, we'll inject mock MCMC samples. - surrogate = adapter.generator.surrogate # pyre-ignore[16] - model = surrogate.model - - # Verify the model is a SaasFullyBayesianSingleTaskGP - self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP) + surrogate = generator.surrogate + model = assert_is_instance(surrogate.model, SaasFullyBayesianSingleTaskGP) # Get training data shape info train_X = model.train_inputs[0] diff --git a/ax/adapter/tests/test_prediction_utils.py b/ax/adapter/tests/test_prediction_utils.py index f72c20fe070..080cfe47015 100644 --- a/ax/adapter/tests/test_prediction_utils.py +++ b/ax/adapter/tests/test_prediction_utils.py @@ -83,14 +83,11 @@ def test_predict_by_features(self) -> None: @mock.patch("ax.adapter.random.RandomAdapter.predict") @mock.patch("ax.adapter.random.RandomAdapter") - # pyre-fixme[3]: Return type must be annotated. def test_predict_by_features_with_non_predicting_model( self, - # pyre-fixme[2]: Parameter must be annotated. - adapter_mock, - # pyre-fixme[2]: Parameter must be annotated. - predict_mock, - ): + adapter_mock: mock.MagicMock, + predict_mock: mock.MagicMock, + ) -> None: ax_client = _set_up_client_for_get_model_predictions_no_next_trial() _attach_completed_trials(ax_client) diff --git a/ax/adapter/tests/test_random_adapter.py b/ax/adapter/tests/test_random_adapter.py index 1cb4b4522a6..689c7ef121b 100644 --- a/ax/adapter/tests/test_random_adapter.py +++ b/ax/adapter/tests/test_random_adapter.py @@ -23,6 +23,7 @@ from ax.exceptions.core import SearchSpaceExhausted from ax.generators.random.base import RandomGenerator from ax.generators.random.sobol import SobolGenerator +from ax.generators.types import TConfig from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_data, @@ -45,7 +46,7 @@ def setUp(self) -> None: ] self.search_space = SearchSpace(self.parameters, parameter_constraints) self.experiment = Experiment(search_space=self.search_space) - self.model_gen_options = {"option": "yes"} + self.model_gen_options: TConfig = {"option": "yes"} def test_fit(self) -> None: adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator()) @@ -79,10 +80,6 @@ def test_gen_w_constraints(self) -> None: pending_observations={}, fixed_features=ObservationFeatures({"z": 3.0}), optimization_config=None, - # pyre-fixme[6]: For 6th param expected `Optional[Dict[str, - # Union[None, Dict[str, typing.Any], OptimizationConfig, - # AcquisitionFunction, float, int, str]]]` but got `Dict[str, - # str]`. model_gen_options=self.model_gen_options, ) gen_args = mock_gen.mock_calls[0][2] @@ -129,10 +126,6 @@ def test_gen_simple(self) -> None: pending_observations={}, fixed_features=ObservationFeatures({}), optimization_config=None, - # pyre-fixme[6]: For 6th param expected `Optional[Dict[str, - # Union[None, Dict[str, typing.Any], OptimizationConfig, - # AcquisitionFunction, float, int, str]]]` but got `Dict[str, - # str]`. model_gen_options=self.model_gen_options, ) gen_args = mock_gen.mock_calls[0][2] diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 8c5288c865c..56d52e30271 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -632,7 +632,7 @@ def test_convert_experiment_data(self) -> None: ordinal_features=[2], discrete_choices={2: list(range(0, 11))}, task_features=[2] if use_task else [], - target_values={2: 0} if use_task else {}, # pyre-ignore + target_values={2: 0.0} if use_task else {}, ) converted_datasets, ordered_outcomes, _ = adapter._convert_experiment_data( experiment_data=experiment_data, diff --git a/ax/adapter/tests/test_torch_moo_adapter.py b/ax/adapter/tests/test_torch_moo_adapter.py index 6e23fb669b4..b01c3dbcd2e 100644 --- a/ax/adapter/tests/test_torch_moo_adapter.py +++ b/ax/adapter/tests/test_torch_moo_adapter.py @@ -322,8 +322,7 @@ def test_hypervolume(self, _, cuda: bool = False) -> None: ) for trial in exp.trials.values(): trial.mark_running(no_runner_required=True).mark_completed() - # pyre-fixme[16]: Optional type has no attribute `metrics`. - metrics_dict = exp.optimization_config.metrics + metrics_dict = none_throws(exp.optimization_config).metrics # Objective thresholds and synthetic observations chosen to have closed-form # hypervolumes to test. objective_thresholds = [ diff --git a/ax/adapter/tests/test_utils.py b/ax/adapter/tests/test_utils.py index a4051f12e17..30853afd2d5 100644 --- a/ax/adapter/tests/test_utils.py +++ b/ax/adapter/tests/test_utils.py @@ -85,7 +85,6 @@ def test_extract_outcome_constraints(self) -> None: OutcomeConstraint(metric=Metric("m1"), op=ComparisonOp.LEQ, bound=0) ] res = extract_outcome_constraints(outcome_constraints, outcomes) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. self.assertEqual(res[0].shape, (1, 3)) self.assertListEqual(list(res[0][0]), [1, 0, 0]) self.assertEqual(res[1][0][0], 0) @@ -137,10 +136,8 @@ def test_extract_objective_thresholds(self) -> None: outcomes=outcomes, ) expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0]) - # pyre-fixme[16]: Optional type has no attribute `__getitem__`. self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3])) self.assertTrue(np.isnan(obj_t[-1])) - # pyre-fixme[16]: Optional type has no attribute `shape`. self.assertEqual(obj_t.shape[0], 4) # Returns NaN for objectives without a threshold. diff --git a/ax/adapter/transforms/tests/test_base_transform.py b/ax/adapter/transforms/tests/test_base_transform.py index e89cec7080f..c2e932daeef 100644 --- a/ax/adapter/transforms/tests/test_base_transform.py +++ b/ax/adapter/transforms/tests/test_base_transform.py @@ -18,6 +18,7 @@ from ax.core.objective import Objective from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig +from ax.core.types import TParameterization from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment @@ -66,10 +67,10 @@ def test_TransformObservations(self) -> None: means = np.array([3.0, 4.0]) metric_signatures = ["a", "b"] covariance = np.array([[1.0, 2.0], [3.0, 4.0]]) - parameters = {"x": 1.0, "y": "cat"} + parameters: TParameterization = {"x": 1.0, "y": "cat"} arm_name = "armmy" observation = Observation( - features=ObservationFeatures(parameters=parameters), # pyre-ignore + features=ObservationFeatures(parameters=parameters), data=ObservationData( metric_signatures=metric_signatures, means=means, covariance=covariance ), diff --git a/ax/adapter/transforms/tests/test_cast_transform.py b/ax/adapter/transforms/tests/test_cast_transform.py index 0dc6940f10c..d41de7085ee 100644 --- a/ax/adapter/transforms/tests/test_cast_transform.py +++ b/ax/adapter/transforms/tests/test_cast_transform.py @@ -35,6 +35,7 @@ ) from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyre_extensions import none_throws class CastTransformTest(TestCase): @@ -179,8 +180,7 @@ def test_transform_observation_features_HSS(self) -> None: self.assertIn(p_name, obsf.parameters) # Check that full parameterization is recorded in metadata self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -197,7 +197,7 @@ def test_transform_observation_features_HSS(self) -> None: self.assertIn(p_name, obsf.parameters) # Check that full parameterization is recorded in metadata self.assertEqual( - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -245,8 +245,7 @@ def test_untransform_observation_features_HSS(self) -> None: }, ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss.parameters, ) @@ -264,7 +263,7 @@ def test_untransform_observation_features_HSS(self) -> None: }, ) self.assertEqual( - obsf.metadata.get(Keys.FULL_PARAMETERIZATION), + none_throws(obsf.metadata).get(Keys.FULL_PARAMETERIZATION), self.obs_feats_hss_2.parameters, ) diff --git a/ax/adapter/transforms/tests/test_choice_encode_transform.py b/ax/adapter/transforms/tests/test_choice_encode_transform.py index 834d6714f80..5bf8ed02208 100644 --- a/ax/adapter/transforms/tests/test_choice_encode_transform.py +++ b/ax/adapter/transforms/tests/test_choice_encode_transform.py @@ -247,8 +247,9 @@ def test_hss_dependents_are_preserved(self) -> None: # x0 should be untouched because it's a fixed parameter. self.assertIsInstance(hss.parameters["x0"], FixedParameter) self.assertEqual(hss.parameters["x0"].parameter_type, ParameterType.BOOL) - # pyre-ignore[16] # Pyre doesn't understand fixed parameters have `.value` - self.assertEqual(hss.parameters["x0"].value, True) + self.assertEqual( + assert_is_instance(hss.parameters["x0"], FixedParameter).value, True + ) self.assertEqual(hss.parameters["x0"].dependents, {True: ["x1", "x2"]}) self.assertFalse(hss.parameters["x1"].is_hierarchical) diff --git a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py index fee4d0bab9d..34d9ae1944c 100644 --- a/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py +++ b/ax/adapter/transforms/tests/test_int_range_to_choice_transform.py @@ -72,7 +72,7 @@ def test_num_choices(self) -> None: "e", lower=3, upper=5, parameter_type=ParameterType.INT ), } - search_space = SearchSpace(parameters=parameters.values()) # pyre-ignore[6] + search_space = SearchSpace(parameters=list(parameters.values())) # Don't specify max_choices (should be set to inf) t = IntRangeToChoice(search_space=search_space) diff --git a/ax/adapter/transforms/tests/test_logit_transform.py b/ax/adapter/transforms/tests/test_logit_transform.py index 38fb28afc8f..1c567c79a76 100644 --- a/ax/adapter/transforms/tests/test_logit_transform.py +++ b/ax/adapter/transforms/tests/test_logit_transform.py @@ -18,6 +18,7 @@ from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment_with_observations from pandas.testing import assert_frame_equal, assert_series_equal +from pyre_extensions import assert_is_instance from scipy.special import expit, logit @@ -107,16 +108,19 @@ def test_InvalidSettings(self) -> None: def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) - # pyre-fixme[16]: `Parameter` has no attribute `lower`. - self.assertEqual(ss2.parameters["x"].lower, logit(0.9)) - # pyre-fixme[16]: `Parameter` has no attribute `upper`. - self.assertEqual(ss2.parameters["x"].upper, logit(0.999)) + self.assertEqual( + assert_is_instance(ss2.parameters["x"], RangeParameter).lower, logit(0.9) + ) + self.assertEqual( + assert_is_instance(ss2.parameters["x"], RangeParameter).upper, logit(0.999) + ) t2 = Logit(search_space=self.search_space_with_target) ss_target = deepcopy(self.search_space_with_target) t2.transform_search_space(ss_target) self.assertEqual(ss_target.parameters["x"].target_value, logit(0.123)) - self.assertEqual(ss_target.parameters["x"].lower, logit(0.1)) - self.assertEqual(ss_target.parameters["x"].upper, logit(0.3)) + x_param = assert_is_instance(ss_target.parameters["x"], RangeParameter) + self.assertEqual(x_param.lower, logit(0.1)) + self.assertEqual(x_param.upper, logit(0.3)) def test_transform_experiment_data(self) -> None: parameterizations = [ diff --git a/ax/adapter/transforms/tests/test_metrics_as_task_transform.py b/ax/adapter/transforms/tests/test_metrics_as_task_transform.py index 44472179380..dd287754588 100644 --- a/ax/adapter/transforms/tests/test_metrics_as_task_transform.py +++ b/ax/adapter/transforms/tests/test_metrics_as_task_transform.py @@ -14,6 +14,7 @@ from ax.core.parameter import ChoiceParameter from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_search_space_for_range_values +from pyre_extensions import assert_is_instance class MetricsAsTaskTransformTest(TestCase): @@ -125,10 +126,10 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(len(new_ss.parameters), 3) new_param = new_ss.parameters["METRIC_TASK"] self.assertIsInstance(new_param, ChoiceParameter) + new_param_choice = assert_is_instance(new_param, ChoiceParameter) self.assertEqual( - # pyre-fixme[16]: `Parameter` has no attribute `values`. - new_param.values, + new_param_choice.values, ["TARGET", "metric1", "metric2"], ) - self.assertTrue(new_param.is_task) # pyre-ignore + self.assertTrue(new_param_choice.is_task) self.assertEqual(new_param.target_value, "TARGET") diff --git a/ax/adapter/transforms/tests/test_one_hot_transform.py b/ax/adapter/transforms/tests/test_one_hot_transform.py index fb7f41accf4..f685c6b6d67 100644 --- a/ax/adapter/transforms/tests/test_one_hot_transform.py +++ b/ax/adapter/transforms/tests/test_one_hot_transform.py @@ -25,6 +25,7 @@ from ax.utils.testing.core_stubs import get_experiment_with_observations from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyre_extensions import assert_is_instance class OneHotTransformTest(TestCase): @@ -130,10 +131,18 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.FLOAT) # Parameter range fixed to [0,1]. - # pyre-fixme[16]: `Parameter` has no attribute `lower`. - self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "0"].lower, 0.0) - # pyre-fixme[16]: `Parameter` has no attribute `upper`. - self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "1"].upper, 1.0) + self.assertEqual( + assert_is_instance( + ss2.parameters["b" + OH_PARAM_INFIX + "0"], RangeParameter + ).lower, + 0.0, + ) + self.assertEqual( + assert_is_instance( + ss2.parameters["b" + OH_PARAM_INFIX + "1"], RangeParameter + ).upper, + 1.0, + ) self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL) # Ensure we error if we try to transform a fidelity parameter diff --git a/ax/adapter/transforms/tests/test_task_encode_transform.py b/ax/adapter/transforms/tests/test_task_encode_transform.py index 42e95d25723..8a985685c32 100644 --- a/ax/adapter/transforms/tests/test_task_encode_transform.py +++ b/ax/adapter/transforms/tests/test_task_encode_transform.py @@ -13,6 +13,7 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance class TaskChoiceToIntTaskChoiceTransformTest(TestCase): @@ -73,8 +74,9 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(ss2.parameters["b"].parameter_type, ParameterType.FLOAT) self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.INT) - # pyre-fixme[16]: `Parameter` has no attribute `values`. - self.assertEqual(ss2.parameters["c"].values, [0, 1]) + self.assertEqual( + assert_is_instance(ss2.parameters["c"], ChoiceParameter).values, [0, 1] + ) self.assertEqual(ss2.parameters["c"].target_value, 0) self.assertEqual(ss2.parameters["c"].dependents, {0: ["b"]}) diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index 8a2c201f4b9..473f2ac68c6 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -212,11 +212,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.STAGED], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.STAGED + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.STAGED + ) ) self.assertIsNotNone(self.batch.time_staged) self.assertTrue(self.batch.status.is_deployed) @@ -240,11 +240,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.RUNNING], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.RUNNING + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.RUNNING + ) ) self.assertIsNotNone(self.batch.time_run_started) self.assertTrue(self.batch.status.expecting_data) @@ -261,11 +261,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.COMPLETED], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.COMPLETED + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.COMPLETED + ) ) self.assertIsNotNone(self.batch.time_completed) self.assertTrue(self.batch.status.is_terminal) @@ -296,11 +296,11 @@ def test_BatchLifecycle(self) -> None: self.experiment.trial_indices_by_status[TrialStatus.CANDIDATE], {0} ) self.assertTrue( - # pyre-fixme[6]: For 1st param expected `Iterable[object]` but got - # `bool`. - all(len(idcs) == 0) - for status, idcs in self.experiment.trial_indices_by_status.items() - if status != TrialStatus.CANDIDATE + all( + len(idcs) == 0 + for status, idcs in self.experiment.trial_indices_by_status.items() + if status != TrialStatus.CANDIDATE + ) ) def test_AbandonBatchTrial(self) -> None: @@ -592,11 +592,9 @@ def test_get_candidate_metadata_from_all_generator_runs(self) -> None: # Check that if we add cand. metadata to gr_2, it will appear in cand. # metadata for the batch. gr_3 = get_generator_run2() - new_cand_metadata = { + new_cand_metadata: dict[str, dict[str, str] | None] | None = { a.signature: {"md_key": f"md_val_{a.signature}"} for a in gr_3.arms } - # pyre-fixme[8]: Attribute has type `Optional[Dict[str, Optional[Dict[str, - # typing.Any]]]]`; used as `Dict[str, Dict[str, str]]`. gr_3._candidate_metadata_by_arm_signature = new_cand_metadata self.batch.add_generator_run(gr_3) gr_3 = self.batch._generator_runs[-1] diff --git a/ax/core/tests/test_generator_run.py b/ax/core/tests/test_generator_run.py index 9360c93293f..ed3b22900d8 100644 --- a/ax/core/tests/test_generator_run.py +++ b/ax/core/tests/test_generator_run.py @@ -17,6 +17,7 @@ get_optimization_config, get_search_space, ) +from pyre_extensions import none_throws GENERATOR_RUN_STR = "GeneratorRun(3 arms, total weight 3.0)" @@ -31,7 +32,7 @@ def setUp(self) -> None: self.search_space = get_search_space() self.arms = get_arms() - self.weights = [2, 1, 1] + self.weights: list[float] = [2, 1, 1] self.unweighted_run = GeneratorRun( arms=self.arms, optimization_config=self.optimization_config, @@ -42,8 +43,6 @@ def setUp(self) -> None: ) self.weighted_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=self.optimization_config, search_space=self.search_space, @@ -56,13 +55,13 @@ def setUp(self) -> None: def test_Init(self) -> None: self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `outcome_constraints`. - len(self.unweighted_run.optimization_config.outcome_constraints), + len( + none_throws(self.unweighted_run.optimization_config).outcome_constraints + ), len(self.optimization_config.outcome_constraints), ) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `parameters`. - len(self.unweighted_run.search_space.parameters), + len(none_throws(self.unweighted_run.search_space).parameters), len(self.search_space.parameters), ) self.assertEqual(str(self.unweighted_run), GENERATOR_RUN_STR) @@ -120,8 +119,6 @@ def test_ModelPredictions(self) -> None: ) run_no_model_predictions = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -150,8 +147,6 @@ def test_ParamDf(self) -> None: def test_BestArm(self) -> None: generator_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -166,8 +161,6 @@ def test_GenMetadata(self) -> None: gm = {"hello": "world"} generator_run = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, optimization_config=get_optimization_config(), search_space=get_search_space(), @@ -178,14 +171,10 @@ def test_GenMetadata(self) -> None: def test_Sortable(self) -> None: generator_run1 = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, ) generator_run2 = GeneratorRun( arms=self.arms, - # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but got - # `List[int]`. weights=self.weights, ) self.assertTrue(generator_run1 < generator_run2) diff --git a/ax/core/tests/test_objective.py b/ax/core/tests/test_objective.py index 57d21c66714..6e28cc39b09 100644 --- a/ax/core/tests/test_objective.py +++ b/ax/core/tests/test_objective.py @@ -64,8 +64,7 @@ def test_Init(self) -> None: def test_MultiObjective(self) -> None: with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.multi_objective.metric + self.multi_objective.metric self.assertEqual(self.multi_objective.metrics, list(self.metrics.values())) minimizes = [obj.minimize for obj in self.multi_objective.objectives] @@ -106,8 +105,7 @@ def test_MultiObjective(self) -> None: def test_ScalarizedObjective(self) -> None: with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.scalarized_objective.metric + self.scalarized_objective.metric self.assertEqual( self.scalarized_objective.metrics, [self.metrics["m1"], self.metrics["m2"]] diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c18ccf25deb..abbc19cea32 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -35,22 +35,19 @@ class ObservationsTest(TestCase): def test_ObservationFeatures(self) -> None: - t = np.datetime64("now") + t = pd.Timestamp.now() + obsf = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + trial_index=2, + start_time=t, + end_time=t, + ) attrs = { "parameters": {"x": 0, "y": "a"}, "trial_index": 2, "start_time": t, "end_time": t, } - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf = ObservationFeatures(**attrs) for k, v in attrs.items(): self.assertEqual(getattr(obsf, k), v) printstr = ( @@ -58,29 +55,21 @@ def test_ObservationFeatures(self) -> None: f"start_time={t}, end_time={t})" ) self.assertEqual(repr(obsf), printstr) - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf2 = ObservationFeatures(**attrs) + obsf2 = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + trial_index=2, + start_time=t, + end_time=t, + ) self.assertEqual(hash(obsf), hash(obsf2)) a = {obsf, obsf2} self.assertEqual(len(a), 1) self.assertEqual(obsf, obsf2) - attrs.pop("trial_index") - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Dict[str, typing.Any]]` - # but got `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[int64]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - # pyre-fixme[6]: For 1st param expected `Optional[Timestamp]` but got - # `Union[Dict[str, Union[int, str]], int, datetime64]`. - obsf3 = ObservationFeatures(**attrs) + obsf3 = ObservationFeatures( + parameters={"x": 0, "y": "a"}, + start_time=t, + end_time=t, + ) self.assertNotEqual(obsf, obsf3) self.assertFalse(obsf == 1) @@ -105,12 +94,9 @@ def test_ObservationFeaturesFromArm(self) -> None: self.assertEqual(obsf.trial_index, 3) def test_UpdateFeatures(self) -> None: - parameters = {"x": 0, "y": "a"} - new_parameters = {"z": "foo"} + parameters: TParameterization = {"x": 0, "y": "a"} + new_parameters: TParameterization = {"z": "foo"} - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, - # int, str]]` but got `Dict[str, Union[int, str]]`. - # pyre-fixme[6]: For 2nd param expected `Optional[int64]` but got `int`. obsf = ObservationFeatures(parameters=parameters, trial_index=3) # Ensure None trial_index doesn't override existing value @@ -119,8 +105,6 @@ def test_UpdateFeatures(self) -> None: # Test override new_obsf = ObservationFeatures( - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Dict[str, str]`. parameters=new_parameters, trial_index=4, start_time=pd.Timestamp("2005-02-25"), @@ -133,16 +117,19 @@ def test_UpdateFeatures(self) -> None: self.assertEqual(obsf.end_time, pd.Timestamp("2005-02-26")) def test_ObservationData(self) -> None: + metric_signatures = ["a", "b"] + means = np.array([4.0, 5.0]) + covariance = np.array([[1.0, 4.0], [3.0, 6.0]]) + obsd = ObservationData( + metric_signatures=metric_signatures, + means=means, + covariance=covariance, + ) attrs = { - "metric_signatures": ["a", "b"], - "means": np.array([4.0, 5.0]), - "covariance": np.array([[1.0, 4.0], [3.0, 6.0]]), + "metric_signatures": metric_signatures, + "means": means, + "covariance": covariance, } - # pyre-fixme[6]: For 1st param expected `List[str]` but got - # `Union[List[str], ndarray]`. - # pyre-fixme[6]: For 1st param expected `ndarray` but got `Union[List[str], - # ndarray]`. - obsd = ObservationData(**attrs) self.assertEqual(obsd.metric_signatures, attrs["metric_signatures"]) self.assertTrue(np.array_equal(obsd.means, attrs["means"])) self.assertTrue(np.array_equal(obsd.covariance, attrs["covariance"])) @@ -258,19 +245,18 @@ def test_ObservationsFromData(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[int, str]], float, - # str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} trials = { obs["trial_index"]: Trial( - experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), ) for obs in truth } @@ -375,8 +361,7 @@ def test_ObservationsFromMapData(self) -> None: arms = [ Arm( name=assert_is_instance(obs["arm_name"], str), - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool,... - parameters=obs["parameters"], + parameters=assert_is_instance(obs["parameters"], dict), ) for obs in truth ] @@ -419,10 +404,18 @@ def test_ObservationsFromMapData(self) -> None: self.assertEqual(obs.features.trial_index, t["trial_index"]) self.assertEqual(obs.data.metric_signatures, [t["metric_name"]]) self.assertEqual(obs.data.metric_signatures, [t["metric_signature"]]) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.means, t["mean_t"])) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"])) + self.assertTrue( + np.array_equal( + obs.data.means, + np.asarray(assert_is_instance(t["mean_t"], np.ndarray)), + ) + ) + self.assertTrue( + np.array_equal( + obs.data.covariance, + np.asarray(assert_is_instance(t["covariance_t"], np.ndarray)), + ) + ) self.assertEqual(obs.arm_name, t["arm_name"]) self.assertEqual(obs.features.metadata, {"step": t["step"]}) @@ -514,37 +507,28 @@ def test_ObservationsFromDataAbandoned(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[float, str]], Dict[str, Union[int, str]], float, - # ndarray, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[float, str]], - # Dict[str, Union[int, str]], float, ndarray, str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} - trials = { - obs["trial_index"]: ( - Trial(experiment, GeneratorRun(arms=[arms[obs["arm_name"]]])) + trials: dict[int, Trial | BatchTrial] = { + assert_is_instance(obs["trial_index"], int): ( + Trial( + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), + ) ) for obs in truth[:-1] - # pyre-fixme[16]: Item `Dict` of `Union[Dict[str, typing.Union[float, - # str]], Dict[str, typing.Union[int, str]], float, ndarray, str]` has no - # attribute `startswith`. - if not obs["arm_name"].startswith("2") + if not assert_is_instance(obs["arm_name"], str).startswith("2") } batch = BatchTrial(experiment, GeneratorRun(arms=[arms["2_0"], arms["2_1"]])) - # pyre-fixme[6]: For 1st param expected - # `SupportsKeysAndGetItem[Union[Dict[str, Union[float, str]], Dict[str, - # Union[int, str]], float, ndarray, str], Trial]` but got `Dict[int, - # BatchTrial]`. - trials.update({2: batch}) - # pyre-fixme[16]: Optional type has no attribute `mark_abandoned`. - trials.get(1).mark_abandoned() - # pyre-fixme[16]: Optional type has no attribute `mark_arm_abandoned`. - trials.get(2).mark_arm_abandoned(arm_name="2_1") + trials[2] = batch + none_throws(trials.get(1)).mark_abandoned() + assert_is_instance(trials.get(2), BatchTrial).mark_arm_abandoned(arm_name="2_1") type(experiment).arms_by_name = PropertyMock(return_value=arms) type(experiment).trials = PropertyMock(return_value=trials) type(experiment).metrics = PropertyMock( @@ -627,19 +611,18 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[None, Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[None, Dict[str, Union[int, str]], - # float, str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() experiment._trial_indices_by_status = {status: set() for status in TrialStatus} trials = { obs["trial_index"]: Trial( - experiment, GeneratorRun(arms=[arms[obs["arm_name"]]]) + experiment, + GeneratorRun(arms=[arms[assert_is_instance(obs["arm_name"], str)]]), ) for obs in truth } @@ -789,11 +772,19 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial( self.assertEqual( obs.data.metric_signatures, obs_truth["metric_signatures"][i] ) - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty... - self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i])) self.assertTrue( - # pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtyp... - np.array_equal(obs.data.covariance, obs_truth["covariance"][i]) + np.array_equal( + obs.data.means, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(obs_truth["means"][i], np.ndarray), + ) + ) + self.assertTrue( + np.array_equal( + obs.data.covariance, + # pyre-ignore[6]: numpy stubs type mismatch. + assert_is_instance(obs_truth["covariance"][i], np.ndarray), + ) ) self.assertEqual(obs.arm_name, obs_truth["arm_name"][i]) self.assertEqual(obs.arm_name, obs_truth["arm_name"][i]) @@ -875,12 +866,10 @@ def test_ObservationsWithCandidateMetadata(self) -> None: }, ] arms = { - # pyre-fixme[6]: For 1st param expected `Optional[str]` but got - # `Union[Dict[str, Union[int, str]], float, str]`. - # pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, - # float, int, str]]` but got `Union[Dict[str, Union[int, str]], float, - # str]`. - obs["arm_name"]: Arm(name=obs["arm_name"], parameters=obs["parameters"]) + assert_is_instance(obs["arm_name"], str): Arm( + name=assert_is_instance(obs["arm_name"], str), + parameters=assert_is_instance(obs["parameters"], dict), + ) for obs in truth } experiment = Mock() @@ -889,9 +878,9 @@ def test_ObservationsWithCandidateMetadata(self) -> None: obs["trial_index"]: Trial( experiment, GeneratorRun( - arms=[arms[obs["arm_name"]]], + arms=[arms[assert_is_instance(obs["arm_name"], str)]], candidate_metadata_by_arm_signature={ - arms[obs["arm_name"]].signature: { + arms[assert_is_instance(obs["arm_name"], str)].signature: { SOME_METADATA_KEY: f"value_{obs['trial_index']}" } }, @@ -919,8 +908,7 @@ def test_ObservationsWithCandidateMetadata(self) -> None: observations = observations_from_data(experiment, data) for observation in observations: self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `get`. - observation.features.metadata.get(SOME_METADATA_KEY), + none_throws(observation.features.metadata).get(SOME_METADATA_KEY), f"value_{observation.features.trial_index}", ) diff --git a/ax/core/tests/test_outcome_constraint.py b/ax/core/tests/test_outcome_constraint.py index fcc6ed79cf5..f288ece1703 100644 --- a/ax/core/tests/test_outcome_constraint.py +++ b/ax/core/tests/test_outcome_constraint.py @@ -253,8 +253,7 @@ def test_RaiseError(self) -> None: ) with self.assertRaises(NotImplementedError): - # pyre-fixme[7]: Expected `None` but got `Metric`. - return self.constraint.metric + self.constraint.metric with self.assertRaises(NotImplementedError): self.constraint.metric = self.metrics[0] diff --git a/ax/core/tests/test_parameter_constraint.py b/ax/core/tests/test_parameter_constraint.py index 822046e1ec8..30bc1be4c2c 100644 --- a/ax/core/tests/test_parameter_constraint.py +++ b/ax/core/tests/test_parameter_constraint.py @@ -82,16 +82,12 @@ def test_Repr(self) -> None: self.assertEqual(str(self.constraint), self.constraint_repr) def test_Validate(self) -> None: - parameters = {"x": 4, "z": 3} + parameters: dict[str, float | int] = {"x": 4, "z": 3} with self.assertRaises(ValueError): - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[float, int]]` - # but got `Dict[str, int]`. self.constraint.check(parameters) # check slack constraint parameters = {"x": 4, "y": 1} - # pyre-fixme[6]: For 1st param expected `Dict[str, Union[float, int]]` but - # got `Dict[str, int]`. self.assertTrue(self.constraint.check(parameters)) # check tight constraint (within numerical tolerance) diff --git a/ax/core/tests/test_runner.py b/ax/core/tests/test_runner.py index 24edb5eb61e..40a34d33276 100644 --- a/ax/core/tests/test_runner.py +++ b/ax/core/tests/test_runner.py @@ -15,8 +15,7 @@ class DummyRunner(Runner): - # pyre-fixme[3]: Return type must be annotated. - def run(self, trial: BaseTrial): + def run(self, trial: BaseTrial) -> dict[str, str]: return {"metadatum": f"value_for_trial_{trial.index}"} diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index 33eef439201..d638968896d 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -114,13 +114,12 @@ def test_basic_properties(self) -> None: def test_adding_new_trials(self) -> None: new_arm = get_arms()[1] - cand_metadata = {new_arm.signature: {"a": "b"}} + cand_metadata: dict[str, dict[str, str] | None] = { + new_arm.signature: {"a": "b"} + } new_trial = self.experiment.new_trial( generator_run=GeneratorRun( arms=[new_arm], - # pyre-fixme[6]: For 2nd param expected `Optional[Dict[str, - # Optional[Dict[str, typing.Any]]]]` but got `Dict[str, Dict[str, - # str]]`. candidate_metadata_by_arm_signature=cand_metadata, ) ) @@ -313,15 +312,13 @@ def stop(self, trial, reason): f"{BaseTrial.__module__}.{BaseTrial.__name__}.lookup_data", return_value=TEST_DATA, ) - # pyre-fixme[3]: Return type must be annotated. - def test_objective_mean(self, _mock): + def test_objective_mean(self, _mock: Mock) -> None: self.assertEqual(self.trial.objective_mean, 1.0) @patch( f"{BaseTrial.__module__}.{BaseTrial.__name__}.lookup_data", return_value=Data() ) - # pyre-fixme[3]: Return type must be annotated. - def test_objective_mean_empty_df(self, _mock): + def test_objective_mean_empty_df(self, _mock: Mock) -> None: with self.assertRaisesRegex(ValueError, "not yet in data for trial."): self.assertIsNone(self.trial.objective_mean) diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index f0810e2949c..ccd64b6c4ec 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -22,6 +22,7 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import OutcomeConstraint +from ax.core.trial import Trial from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.core.utils import ( @@ -51,7 +52,7 @@ get_experiment, get_hierarchical_search_space_experiment, ) -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws class UtilsTest(TestCase): @@ -1151,7 +1152,7 @@ def test_curve_data(self) -> None: ) trial = exp.trials[0] trial.mark_running(no_runner_required=True) - arm_name = trial.arm.name # pyre-ignore[16] + arm_name = none_throws(assert_is_instance(trial, Trial).arm).name # Both metrics present at various steps → COMPLETE. df_both = pd.DataFrame( @@ -1183,7 +1184,7 @@ def test_curve_data(self) -> None: exp2.optimization_config = none_throws(exp.optimization_config) trial2 = exp2.trials[0] trial2.mark_running(no_runner_required=True) - arm_name2 = trial2.arm.name # pyre-ignore[16] + arm_name2 = none_throws(assert_is_instance(trial2, Trial).arm).name df_partial = pd.DataFrame( [ {