Skip to content

Commit

Permalink
Document TorchAdapter._untransform_objective_thresholds, remove unuse…
Browse files Browse the repository at this point in the history
…d kwarg (facebook#3400)

Summary:

As titled. Just adding clarity to methods that make me question "what does this do, what is it used for?"

Differential Revision: D69989096
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 21, 2025
1 parent 63a1eaf commit 80f9d77
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def infer_objective_thresholds(
return self._untransform_objective_thresholds(
objective_thresholds=obj_thresholds,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
opt_config_metrics=torch_opt_config.opt_config_metrics,
fixed_features=torch_opt_config.fixed_features,
)
Expand Down Expand Up @@ -727,7 +726,6 @@ def _gen(
self._untransform_objective_thresholds(
objective_thresholds=gen_metadata["objective_thresholds"],
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
opt_config_metrics=torch_opt_config.opt_config_metrics,
fixed_features=torch_opt_config.fixed_features,
)
Expand Down Expand Up @@ -923,24 +921,38 @@ def _untransform_objective_thresholds(
self,
objective_thresholds: Tensor,
objective_weights: Tensor,
bounds: list[tuple[int | float, int | float]],
opt_config_metrics: dict[str, Metric],
fixed_features: dict[int, float] | None,
) -> list[ObjectiveThreshold]:
thresholds_np = objective_thresholds.cpu().numpy()
"""Converts tensor-valued (possibly inferred) objective thresholds to
``ObjectiveThreshold`` objects, and untransforms to ensure they are
on the same raw scale as the original optimization config.
Args:
objective_thresholds: A tensor of (possibly inferred) objective thresholds
of shape `(num_metrics)`.
objective_weights: A tensor of objective weights that denote whether each
objective is being minimized (-1) or maximized (+1). May also include
0 values, which represents outcome constraints and tracking metrics.
opt_config_metrics: A dictionary mapping the metric name to the ``Metric``
object from the original optimization config.
fixed_features: A map {feature_index: value} for features that should be
fixed to a particular value during generation. This typically includes
the target trial index for multi-task applications.
Returns:
A list of ``ObjectiveThreshold``s on the raw, untransformed scale.
"""
idxs = objective_weights.nonzero().view(-1).tolist()

# Create transformed ObjectiveThresholds from numpy thresholds.
# Create transformed ObjectiveThresholds from tensor thresholds.
thresholds = []
for idx in idxs:
sign = torch.sign(objective_weights[idx])
thresholds.append(
ObjectiveThreshold(
metric=opt_config_metrics[self.outcomes[idx]],
# pyre-fixme[6]: In call `ObjectiveThreshold.__init__`,
# for argument `bound`, expected `float` but got
# `ndarray[typing.Any, typing.Any]`
bound=thresholds_np[idx],
bound=float(objective_thresholds[idx].item()),
relative=False,
op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ,
)
Expand Down

0 comments on commit 80f9d77

Please sign in to comment.