Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document TorchAdapter._untransform_objective_thresholds, remove unused kwarg #3400

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def infer_objective_thresholds(
return self._untransform_objective_thresholds(
objective_thresholds=obj_thresholds,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
opt_config_metrics=torch_opt_config.opt_config_metrics,
fixed_features=torch_opt_config.fixed_features,
)
Expand Down Expand Up @@ -727,7 +726,6 @@ def _gen(
self._untransform_objective_thresholds(
objective_thresholds=gen_metadata["objective_thresholds"],
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
opt_config_metrics=torch_opt_config.opt_config_metrics,
fixed_features=torch_opt_config.fixed_features,
)
Expand Down Expand Up @@ -923,24 +921,38 @@ def _untransform_objective_thresholds(
self,
objective_thresholds: Tensor,
objective_weights: Tensor,
bounds: list[tuple[int | float, int | float]],
opt_config_metrics: dict[str, Metric],
fixed_features: dict[int, float] | None,
) -> list[ObjectiveThreshold]:
thresholds_np = objective_thresholds.cpu().numpy()
"""Converts tensor-valued (possibly inferred) objective thresholds to
``ObjectiveThreshold`` objects, and untransforms to ensure they are
on the same raw scale as the original optimization config.

Args:
objective_thresholds: A tensor of (possibly inferred) objective thresholds
of shape `(num_metrics)`.
objective_weights: A tensor of objective weights that denote whether each
objective is being minimized (-1) or maximized (+1). May also include
0 values, which represents outcome constraints and tracking metrics.
opt_config_metrics: A dictionary mapping the metric name to the ``Metric``
object from the original optimization config.
fixed_features: A map {feature_index: value} for features that should be
fixed to a particular value during generation. This typically includes
the target trial index for multi-task applications.

Returns:
A list of ``ObjectiveThreshold``s on the raw, untransformed scale.
"""
idxs = objective_weights.nonzero().view(-1).tolist()

# Create transformed ObjectiveThresholds from numpy thresholds.
# Create transformed ObjectiveThresholds from tensor thresholds.
thresholds = []
for idx in idxs:
sign = torch.sign(objective_weights[idx])
thresholds.append(
ObjectiveThreshold(
metric=opt_config_metrics[self.outcomes[idx]],
# pyre-fixme[6]: In call `ObjectiveThreshold.__init__`,
# for argument `bound`, expected `float` but got
# `ndarray[typing.Any, typing.Any]`
bound=thresholds_np[idx],
bound=float(objective_thresholds[idx].item()),
relative=False,
op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ,
)
Expand Down