10
10
from ax .core .batch_trial import BatchTrial
11
11
from ax .core .data import Data
12
12
from ax .core .experiment import Experiment
13
+ from ax .core .objective import MultiObjective
13
14
from ax .core .optimization_config import OptimizationConfig
14
15
from ax .core .trial import Trial
15
16
from ax .core .types import ComparisonOp
@@ -41,17 +42,25 @@ def get_missing_metrics(
41
42
Returns:
42
43
A NamedTuple(missing_objective, Dict[str, missing_outcome_constraint])
43
44
"""
44
- objective_name = optimization_config .objective .metric .name
45
+ objective = optimization_config .objective
46
+ if isinstance (objective , MultiObjective ): # pragma: no cover
47
+ objective_metric_names = [m .name for m in objective .metrics ]
48
+ else :
49
+ objective_metric_names = [optimization_config .objective .metric .name ]
50
+
45
51
outcome_constraints_metric_names = [
46
52
outcome_constraint .metric .name
47
53
for outcome_constraint in optimization_config .outcome_constraints
48
54
]
49
- missing_objective = _get_missing_arm_trial_pairs (data , objective_name )
55
+ missing_objectives = {
56
+ objective_metric_name : _get_missing_arm_trial_pairs (data , objective_metric_name )
57
+ for objective_metric_name in objective_metric_names
58
+ }
50
59
missing_outcome_constraints = get_missing_metrics_by_name (
51
60
data , outcome_constraints_metric_names
52
61
)
53
62
all_metric_names = set (data .df ["metric_name" ])
54
- optimization_config_metric_names = { objective_name } .union (
63
+ optimization_config_metric_names = set ( missing_objectives . keys ()) .union (
55
64
outcome_constraints_metric_names
56
65
)
57
66
missing_tracking_metric_names = all_metric_names .difference (
@@ -61,9 +70,7 @@ def get_missing_metrics(
61
70
data = data , metric_names = missing_tracking_metric_names
62
71
)
63
72
return MissingMetrics (
64
- objective = {objective_name : missing_objective }
65
- if len (missing_objective ) > 0
66
- else {},
73
+ objective = {k : v for k , v in missing_objectives .items () if len (v ) > 0 },
67
74
outcome_constraints = {
68
75
k : v for k , v in missing_outcome_constraints .items () if len (v ) > 0
69
76
},
0 commit comments