88import json
99from collections .abc import Iterable , Sequence
1010from logging import Logger
11- from typing import Any , Literal , Self
11+ from typing import Any , cast , Literal , Self
1212
1313import numpy as np
1414import pandas as pd
3939from ax .core .runner import Runner
4040from ax .core .trial import Trial
4141from ax .core .trial_status import TrialStatus # Used as a return type
42+ from ax .core .types import TParameterization as CoreTParameterization
4243from ax .early_stopping .strategies import (
4344 BaseEarlyStoppingStrategy ,
4445 PercentileEarlyStoppingStrategy ,
@@ -183,8 +184,7 @@ def configure_optimization(
183184 pruning_target_arm : Arm | None = None
184185 if pruning_target_parameterization is not None :
185186 self ._experiment .search_space .validate_membership (
186- # pyre-fixme[6]: Core Ax TParameterization is dict not Mapping
187- parameters = pruning_target_parameterization
187+ parameters = cast (CoreTParameterization , pruning_target_parameterization )
188188 )
189189 pruning_target_arm = Arm (
190190 parameters = pruning_target_parameterization , name = "pruning_target"
@@ -442,9 +442,9 @@ def get_next_trials(
442442 experiment = self ._experiment ,
443443 n = 1 ,
444444 fixed_features = (
445- # pyre-fixme[6]: Type narrowing broken because core Ax
446- # TParameterization is dict not Mapping
447- ObservationFeatures ( parameters = fixed_parameters )
445+ ObservationFeatures (
446+ parameters = cast ( CoreTParameterization , fixed_parameters )
447+ )
448448 if fixed_parameters is not None
449449 else None
450450 ),
@@ -483,9 +483,10 @@ def get_next_trials(
483483 experiment = self ._experiment , trials = trials
484484 )
485485
486- # pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
487- # None, but we do not allow this in the API.
488- return {trial .index : none_throws (trial .arm ).parameters for trial in trials }
486+ return {
487+ trial .index : cast (TParameterization , none_throws (trial .arm ).parameters )
488+ for trial in trials
489+ }
489490
490491 def complete_trial (
491492 self ,
@@ -573,9 +574,7 @@ def attach_trial(
573574 The index of the attached trial.
574575 """
575576 _ , trial_index = self ._experiment .attach_trial (
576- # pyre-fixme[6]: Type narrowing broken because core Ax TParameterization
577- # is dict not Mapping
578- parameterizations = [parameters ],
577+ parameterizations = [cast (CoreTParameterization , parameters )],
579578 arm_names = [arm_name ] if arm_name else None ,
580579 )
581580
@@ -888,13 +887,14 @@ def get_best_parameterization(
888887 )
889888 )
890889
891- # pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
892- # None but we do not allow this in the API.
893- return BestPointMixin ._to_best_point_tuple (
894- experiment = self ._experiment ,
895- trial_index = trial_index ,
896- parameterization = parameterization ,
897- model_prediction = model_prediction ,
890+ return cast (
891+ tuple [TParameterization , TOutcome , int , str ],
892+ BestPointMixin ._to_best_point_tuple (
893+ experiment = self ._experiment ,
894+ trial_index = trial_index ,
895+ parameterization = parameterization ,
896+ model_prediction = model_prediction ,
897+ ),
898898 )
899899
900900 def get_pareto_frontier (
@@ -945,14 +945,15 @@ def get_pareto_frontier(
945945 use_model_predictions = use_model_predictions ,
946946 )
947947
948- # pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
949- # None but we do not allow this in the API.
950948 return [
951- BestPointMixin ._to_best_point_tuple (
952- experiment = self ._experiment ,
953- trial_index = trial_index ,
954- parameterization = parameterization ,
955- model_prediction = model_prediction ,
949+ cast (
950+ tuple [TParameterization , TOutcome , int , str ],
951+ BestPointMixin ._to_best_point_tuple (
952+ experiment = self ._experiment ,
953+ trial_index = trial_index ,
954+ parameterization = parameterization ,
955+ model_prediction = model_prediction ,
956+ ),
956957 )
957958 for trial_index , (parameterization , model_prediction ) in frontier .items ()
958959 ]
@@ -978,9 +979,9 @@ def predict(
978979 try :
979980 mean , covariance = none_throws (self ._generation_strategy .adapter ).predict (
980981 observation_features = [
981- # pyre-fixme[6]: Core Ax allows users to specify TParameterization
982- # values as None but we do not allow this in the API.
983- ObservationFeatures ( parameters = parameters )
982+ ObservationFeatures (
983+ parameters = cast ( CoreTParameterization , parameters )
984+ )
984985 for parameters in points
985986 ]
986987 )
0 commit comments