@@ -238,7 +238,6 @@ def infer_objective_thresholds(
238
238
return self ._untransform_objective_thresholds (
239
239
objective_thresholds = obj_thresholds ,
240
240
objective_weights = torch_opt_config .objective_weights ,
241
- bounds = search_space_digest .bounds ,
242
241
opt_config_metrics = torch_opt_config .opt_config_metrics ,
243
242
fixed_features = torch_opt_config .fixed_features ,
244
243
)
@@ -727,7 +726,6 @@ def _gen(
727
726
self ._untransform_objective_thresholds (
728
727
objective_thresholds = gen_metadata ["objective_thresholds" ],
729
728
objective_weights = torch_opt_config .objective_weights ,
730
- bounds = search_space_digest .bounds ,
731
729
opt_config_metrics = torch_opt_config .opt_config_metrics ,
732
730
fixed_features = torch_opt_config .fixed_features ,
733
731
)
@@ -923,24 +921,38 @@ def _untransform_objective_thresholds(
923
921
self ,
924
922
objective_thresholds : Tensor ,
925
923
objective_weights : Tensor ,
926
- bounds : list [tuple [int | float , int | float ]],
927
924
opt_config_metrics : dict [str , Metric ],
928
925
fixed_features : dict [int , float ] | None ,
929
926
) -> list [ObjectiveThreshold ]:
930
- thresholds_np = objective_thresholds .cpu ().numpy ()
927
+ """Converts tensor-valued (possibly inferred) objective thresholds to
928
+ ``ObjectiveThreshold`` objects, and untransforms to ensure they are
929
+ on the same raw scale as the original optimization config.
930
+
931
+ Args:
932
+ objective_thresholds: A tensor of (possibly inferred) objective thresholds
933
+ of shape `(num_metrics)`.
934
+ objective_weights: A tensor of objective weights that denote whether each
935
+ objective is being minimized (-1) or maximized (+1). May also include
936
+ 0 values, which represents outcome constraints and tracking metrics.
937
+ opt_config_metrics: A dictionary mapping the metric name to the ``Metric``
938
+ object from the original optimization config.
939
+ fixed_features: A map {feature_index: value} for features that should be
940
+ fixed to a particular value during generation. This typically includes
941
+ the target trial index for multi-task applications.
942
+
943
+ Returns:
944
+ A list of ``ObjectiveThreshold``s on the raw, untransformed scale.
945
+ """
931
946
idxs = objective_weights .nonzero ().view (- 1 ).tolist ()
932
947
933
- # Create transformed ObjectiveThresholds from numpy thresholds.
948
+ # Create transformed ObjectiveThresholds from tensor thresholds.
934
949
thresholds = []
935
950
for idx in idxs :
936
951
sign = torch .sign (objective_weights [idx ])
937
952
thresholds .append (
938
953
ObjectiveThreshold (
939
954
metric = opt_config_metrics [self .outcomes [idx ]],
940
- # pyre-fixme[6]: In call `ObjectiveThreshold.__init__`,
941
- # for argument `bound`, expected `float` but got
942
- # `ndarray[typing.Any, typing.Any]`
943
- bound = thresholds_np [idx ],
955
+ bound = float (objective_thresholds [idx ].item ()),
944
956
relative = False ,
945
957
op = ComparisonOp .LEQ if sign < 0 else ComparisonOp .GEQ ,
946
958
)
0 commit comments