From 80f9d77844779c0d0e74e3bd82c947b889166c17 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 21 Feb 2025 15:24:46 -0800 Subject: [PATCH] Document TorchAdapter._untransform_objective_thresholds, remove unused kwarg (#3400) Summary: As titled. Just adding clarity to methods that make me question "what does this do, what is it used for?" Differential Revision: D69989096 --- ax/modelbridge/torch.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 54ee8d13811..91df07f8bc9 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -238,7 +238,6 @@ def infer_objective_thresholds( return self._untransform_objective_thresholds( objective_thresholds=obj_thresholds, objective_weights=torch_opt_config.objective_weights, - bounds=search_space_digest.bounds, opt_config_metrics=torch_opt_config.opt_config_metrics, fixed_features=torch_opt_config.fixed_features, ) @@ -727,7 +726,6 @@ def _gen( self._untransform_objective_thresholds( objective_thresholds=gen_metadata["objective_thresholds"], objective_weights=torch_opt_config.objective_weights, - bounds=search_space_digest.bounds, opt_config_metrics=torch_opt_config.opt_config_metrics, fixed_features=torch_opt_config.fixed_features, ) @@ -923,24 +921,38 @@ def _untransform_objective_thresholds( self, objective_thresholds: Tensor, objective_weights: Tensor, - bounds: list[tuple[int | float, int | float]], opt_config_metrics: dict[str, Metric], fixed_features: dict[int, float] | None, ) -> list[ObjectiveThreshold]: - thresholds_np = objective_thresholds.cpu().numpy() + """Converts tensor-valued (possibly inferred) objective thresholds to + ``ObjectiveThreshold`` objects, and untransforms to ensure they are + on the same raw scale as the original optimization config. + + Args: + objective_thresholds: A tensor of (possibly inferred) objective thresholds + of shape `(num_metrics)`. + objective_weights: A tensor of objective weights that denote whether each + objective is being minimized (-1) or maximized (+1). May also include + 0 values, which represents outcome constraints and tracking metrics. + opt_config_metrics: A dictionary mapping the metric name to the ``Metric`` + object from the original optimization config. + fixed_features: A map {feature_index: value} for features that should be + fixed to a particular value during generation. This typically includes + the target trial index for multi-task applications. + + Returns: + A list of ``ObjectiveThreshold``s on the raw, untransformed scale. + """ idxs = objective_weights.nonzero().view(-1).tolist() - # Create transformed ObjectiveThresholds from numpy thresholds. + # Create transformed ObjectiveThresholds from tensor thresholds. thresholds = [] for idx in idxs: sign = torch.sign(objective_weights[idx]) thresholds.append( ObjectiveThreshold( metric=opt_config_metrics[self.outcomes[idx]], - # pyre-fixme[6]: In call `ObjectiveThreshold.__init__`, - # for argument `bound`, expected `float` but got - # `ndarray[typing.Any, typing.Any]` - bound=thresholds_np[idx], + bound=float(objective_thresholds[idx].item()), relative=False, op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ, )