6
6
7
7
# pyre-strict
8
8
9
- import copy
10
9
from copy import deepcopy
11
10
from itertools import combinations
12
11
from logging import Logger
13
- from typing import cast , Dict , List , NamedTuple , Optional , Tuple , Union
12
+ from typing import Dict , List , NamedTuple , Optional , Tuple , Union
14
13
15
14
import numpy as np
16
15
import torch
17
16
from ax .core .batch_trial import BatchTrial
18
17
from ax .core .data import Data
19
18
from ax .core .experiment import Experiment
20
19
from ax .core .metric import Metric
21
- from ax .core .objective import ScalarizedObjective
20
+ from ax .core .objective import MultiObjective , ScalarizedObjective
22
21
from ax .core .observation import ObservationFeatures
23
- from ax .core .optimization_config import (
24
- MultiObjectiveOptimizationConfig ,
25
- OptimizationConfig ,
26
- )
22
+ from ax .core .optimization_config import MultiObjectiveOptimizationConfig
27
23
from ax .core .outcome_constraint import (
28
24
ComparisonOp ,
29
25
ObjectiveThreshold ,
43
39
from ax .models .torch .posterior_mean import get_PosteriorMean
44
40
from ax .models .torch_base import TorchModel
45
41
from ax .utils .common .logger import get_logger
42
+ from ax .utils .common .typeutils import checked_cast
46
43
from ax .utils .stats .statstools import relativize
47
44
from botorch .utils .multi_objective import is_non_dominated
48
45
from botorch .utils .multi_objective .hypervolume import infer_reference_point
@@ -615,11 +612,24 @@ def infer_reference_point_from_experiment(
615
612
# when calculating the Pareto front. Also, defining a multiplier to turn all
616
613
# the objectives to be maximized. Note that the multiplier at this point
617
614
# contains 0 for outcome_constraint metrics, but this will be dropped later.
618
- dummy_rp = copy . deepcopy (
619
- experiment .optimization_config . objective_thresholds # pyre-ignore
615
+ opt_config = checked_cast (
616
+ MultiObjectiveOptimizationConfig , experiment .optimization_config
620
617
)
618
+ inferred_rp = _get_objective_thresholds (optimization_config = opt_config )
621
619
multiplier = [0 ] * len (objective_orders )
622
- for ot in dummy_rp :
620
+ if len (opt_config .objective_thresholds ) > 0 :
621
+ inferred_rp = deepcopy (opt_config .objective_thresholds )
622
+ else :
623
+ inferred_rp = []
624
+ for objective in checked_cast (MultiObjective , opt_config .objective ).objectives :
625
+ ot = ObjectiveThreshold (
626
+ metric = objective .metric ,
627
+ bound = 0.0 , # dummy value
628
+ op = ComparisonOp .LEQ if objective .minimize else ComparisonOp .GEQ ,
629
+ relative = False ,
630
+ )
631
+ inferred_rp .append (ot )
632
+ for ot in inferred_rp :
623
633
# In the following, we find the index of the objective in
624
634
# `objective_orders`. If there is an objective that does not exist
625
635
# in `obs_data`, a ValueError is raised.
@@ -640,12 +650,10 @@ def infer_reference_point_from_experiment(
640
650
modelbridge = mb_reference ,
641
651
observation_features = obs_feats ,
642
652
observation_data = obs_data ,
643
- objective_thresholds = dummy_rp ,
653
+ objective_thresholds = inferred_rp ,
644
654
use_model_predictions = False ,
645
655
)
646
-
647
656
if len (frontier_observations ) == 0 :
648
- opt_config = cast (OptimizationConfig , mb_reference ._optimization_config )
649
657
outcome_constraints = opt_config ._outcome_constraints
650
658
if len (outcome_constraints ) == 0 :
651
659
raise RuntimeError (
@@ -665,10 +673,11 @@ def infer_reference_point_from_experiment(
665
673
modelbridge = mb_reference ,
666
674
observation_features = obs_feats ,
667
675
observation_data = obs_data ,
668
- objective_thresholds = dummy_rp ,
676
+ objective_thresholds = inferred_rp ,
669
677
use_model_predictions = False ,
670
678
)
671
- opt_config ._outcome_constraints = outcome_constraints # restoring constraints
679
+ # restoring constraints
680
+ opt_config ._outcome_constraints = outcome_constraints
672
681
673
682
# Need to reshuffle columns of `f` and `obj_w` to be consistent
674
683
# with objective_orders.
@@ -698,15 +707,38 @@ def infer_reference_point_from_experiment(
698
707
x for (i , x ) in enumerate (objective_orders ) if multiplier [i ] != 0
699
708
]
700
709
701
- # Constructing the objective thresholds.
702
- # NOTE: This assumes that objective_thresholds is already initialized.
703
- nadir_objective_thresholds = copy .deepcopy (
704
- experiment .optimization_config .objective_thresholds
705
- )
706
-
707
- for obj_threshold in nadir_objective_thresholds :
710
+ for obj_threshold in inferred_rp :
708
711
obj_threshold .bound = rp [
709
712
objective_orders_reduced .index (obj_threshold .metric .name )
710
713
].item ()
714
+ return inferred_rp
711
715
712
- return nadir_objective_thresholds
716
+
717
+ def _get_objective_thresholds (
718
+ optimization_config : MultiObjectiveOptimizationConfig ,
719
+ ) -> List [ObjectiveThreshold ]:
720
+ """Get objective thresholds for an optimization config.
721
+
722
+ This will return objective thresholds with dummy values if there are
723
+ no objective thresholds on the optimization config.
724
+
725
+ Args:
726
+ optimization_config: Optimization config.
727
+
728
+ Returns:
729
+ List of objective thresholds.
730
+ """
731
+ if optimization_config .objective_thresholds is not None :
732
+ return deepcopy (optimization_config .objective_thresholds )
733
+ objective_thresholds = []
734
+ for objective in checked_cast (
735
+ MultiObjective , optimization_config .objective
736
+ ).objectives :
737
+ ot = ObjectiveThreshold (
738
+ metric = objective .metric ,
739
+ bound = 0.0 , # dummy value
740
+ op = ComparisonOp .LEQ if objective .minimize else ComparisonOp .GEQ ,
741
+ relative = False ,
742
+ )
743
+ objective_thresholds .append (ot )
744
+ return objective_thresholds
0 commit comments