Skip to content

Commit 80f9d77

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Document TorchAdapter._untransform_objective_thresholds, remove unused kwarg (facebook#3400)
Summary: As titled. Just adding clarity to methods that make me question "what does this do, what is it used for?" Differential Revision: D69989096
1 parent 63a1eaf commit 80f9d77

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

ax/modelbridge/torch.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def infer_objective_thresholds(
238238
return self._untransform_objective_thresholds(
239239
objective_thresholds=obj_thresholds,
240240
objective_weights=torch_opt_config.objective_weights,
241-
bounds=search_space_digest.bounds,
242241
opt_config_metrics=torch_opt_config.opt_config_metrics,
243242
fixed_features=torch_opt_config.fixed_features,
244243
)
@@ -727,7 +726,6 @@ def _gen(
727726
self._untransform_objective_thresholds(
728727
objective_thresholds=gen_metadata["objective_thresholds"],
729728
objective_weights=torch_opt_config.objective_weights,
730-
bounds=search_space_digest.bounds,
731729
opt_config_metrics=torch_opt_config.opt_config_metrics,
732730
fixed_features=torch_opt_config.fixed_features,
733731
)
@@ -923,24 +921,38 @@ def _untransform_objective_thresholds(
923921
self,
924922
objective_thresholds: Tensor,
925923
objective_weights: Tensor,
926-
bounds: list[tuple[int | float, int | float]],
927924
opt_config_metrics: dict[str, Metric],
928925
fixed_features: dict[int, float] | None,
929926
) -> list[ObjectiveThreshold]:
930-
thresholds_np = objective_thresholds.cpu().numpy()
927+
"""Converts tensor-valued (possibly inferred) objective thresholds to
928+
``ObjectiveThreshold`` objects, and untransforms to ensure they are
929+
on the same raw scale as the original optimization config.
930+
931+
Args:
932+
objective_thresholds: A tensor of (possibly inferred) objective thresholds
933+
of shape `(num_metrics)`.
934+
objective_weights: A tensor of objective weights that denote whether each
935+
objective is being minimized (-1) or maximized (+1). May also include
936+
0 values, which represents outcome constraints and tracking metrics.
937+
opt_config_metrics: A dictionary mapping the metric name to the ``Metric``
938+
object from the original optimization config.
939+
fixed_features: A map {feature_index: value} for features that should be
940+
fixed to a particular value during generation. This typically includes
941+
the target trial index for multi-task applications.
942+
943+
Returns:
944+
A list of ``ObjectiveThreshold``s on the raw, untransformed scale.
945+
"""
931946
idxs = objective_weights.nonzero().view(-1).tolist()
932947

933-
# Create transformed ObjectiveThresholds from numpy thresholds.
948+
# Create transformed ObjectiveThresholds from tensor thresholds.
934949
thresholds = []
935950
for idx in idxs:
936951
sign = torch.sign(objective_weights[idx])
937952
thresholds.append(
938953
ObjectiveThreshold(
939954
metric=opt_config_metrics[self.outcomes[idx]],
940-
# pyre-fixme[6]: In call `ObjectiveThreshold.__init__`,
941-
# for argument `bound`, expected `float` but got
942-
# `ndarray[typing.Any, typing.Any]`
943-
bound=thresholds_np[idx],
955+
bound=float(objective_thresholds[idx].item()),
944956
relative=False,
945957
op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ,
946958
)

0 commit comments

Comments
 (0)