Skip to content

Commit e9979bc

Browse files
mgrange1998facebook-github-bot
authored andcommitted
Support ScalarizedOutcomeConstraint in _prepare_p_feasible (#4856)
Summary: Implementing support for ScalarizedOutcomeConstraint in _prepare_p_feasible and _prepare_p_feasible_per_constraint as specified in the the task T235432214 1. Modified _prepare_p_feasible (lines 621-642): Replaced the old oc_names loop with direct handling that uses _get_scalarized_constraint_mean_and_sem for ScalarizedOutcomeConstraint instances. This removes the TODO comment for T235432214 2. Modified _prepare_p_feasible_per_constraint (lines 706-749): Applied the same pattern to properly compute mean/sigma for scalarized constraints using the helper function. Reviewed By: ItsMrLin Differential Revision: D92185974
1 parent ce4dc42 commit e9979bc

2 files changed

Lines changed: 145 additions & 42 deletions

File tree

ax/analysis/tests/test_utils.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,22 @@
1010
import numpy as np
1111
import pandas as pd
1212
from ax.analysis.plotly.utils import STALE_FAIL_REASON, truncate_label
13-
from ax.analysis.utils import _relativize_df_with_sq, prepare_arm_data
13+
from ax.analysis.utils import (
14+
_get_scalarized_constraint_mean_and_sem,
15+
_prepare_p_feasible,
16+
_relativize_df_with_sq,
17+
prepare_arm_data,
18+
)
1419
from ax.api.client import Client
1520
from ax.api.configs import RangeParameterConfig
1621
from ax.core.arm import Arm
1722
from ax.core.batch_trial import BatchTrial
1823
from ax.core.data import relativize_dataframe
1924
from ax.core.experiment import Experiment
2025
from ax.core.metric import Metric
26+
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
2127
from ax.core.trial_status import TrialStatus # noqa
28+
from ax.core.types import ComparisonOp
2229
from ax.exceptions.core import UserInputError
2330
from ax.utils.common.testutils import TestCase
2431
from ax.utils.testing.core_stubs import get_offline_experiments, get_online_experiments
@@ -865,3 +872,58 @@ def test_offline(self) -> None:
865872
trial_index=trial_index,
866873
additional_arms=additional_arms,
867874
)
875+
876+
def test_scalarized_constraints(self) -> None:
877+
df = pd.DataFrame(
878+
{
879+
"trial_index": [0, 0],
880+
"arm_name": ["arm1", "arm2"],
881+
"m1_mean": [5.0, 15.0],
882+
"m1_sem": [1.0, 1.0],
883+
"m2_mean": [5.0, 15.0],
884+
"m2_sem": [1.0, 1.0],
885+
"regular_mean": [8.0, 12.0],
886+
"regular_sem": [0.5, 0.5],
887+
}
888+
)
889+
890+
scalarized_constraint = ScalarizedOutcomeConstraint(
891+
metrics=[Metric(name="m1"), Metric(name="m2")],
892+
weights=[1.0, 1.0],
893+
op=ComparisonOp.LEQ,
894+
bound=25.0,
895+
relative=False,
896+
)
897+
898+
# Helper math: mean = w1*m1 + w2*m2, SEM = sqrt(w1^2*s1^2 + w2^2*s2^2)
899+
mean, sem = _get_scalarized_constraint_mean_and_sem(df, scalarized_constraint)
900+
np.testing.assert_array_almost_equal(mean, [10.0, 30.0])
901+
np.testing.assert_array_almost_equal(sem, [np.sqrt(2), np.sqrt(2)])
902+
903+
# Missing metric returns NaN mean and zero SEM
904+
missing_constraint = ScalarizedOutcomeConstraint(
905+
metrics=[Metric(name="m1"), Metric(name="missing")],
906+
weights=[1.0, 1.0],
907+
op=ComparisonOp.LEQ,
908+
bound=10.0,
909+
)
910+
mean, sem = _get_scalarized_constraint_mean_and_sem(df, missing_constraint)
911+
self.assertTrue(np.all(np.isnan(mean)))
912+
np.testing.assert_array_equal(sem, np.zeros(2))
913+
914+
# p_feasible with mixed regular + scalarized constraints
915+
regular_constraint = OutcomeConstraint(
916+
metric=Metric(name="regular"),
917+
op=ComparisonOp.LEQ,
918+
bound=10.0,
919+
relative=False,
920+
)
921+
p_feasible = _prepare_p_feasible(
922+
df=df,
923+
status_quo_df=None,
924+
outcome_constraints=[regular_constraint, scalarized_constraint],
925+
)
926+
self.assertFalse(p_feasible.isna().any())
927+
# arm1 (regular=8, scalarized=10) more feasible than
928+
# arm2 (regular=12, scalarized=30)
929+
self.assertGreater(p_feasible.iloc[0], p_feasible.iloc[1])

ax/analysis/utils.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Sequence
1010

1111
import numpy as np
12+
import numpy.typing as npt
1213
import pandas as pd
1314
import torch
1415
from ax.adapter.base import Adapter
@@ -540,6 +541,54 @@ def _extract_generation_node_name(trial: BaseTrial, arm: Arm) -> str:
540541
return Keys.UNKNOWN_GENERATION_NODE.value
541542

542543

544+
def _get_scalarized_constraint_mean_and_sem(
545+
df: pd.DataFrame,
546+
constraint: ScalarizedOutcomeConstraint,
547+
) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]:
548+
"""
549+
Compute the combined mean and SEM for a ScalarizedOutcomeConstraint.
550+
551+
For independent random variables:
552+
combined_mean = sum(weight_i * mean_i)
553+
combined_sem = sqrt(sum((weight_i * sem_i)^2))
554+
555+
Args:
556+
df: DataFrame with "{metric_name}_mean" and "{metric_name}_sem" columns.
557+
constraint: The ScalarizedOutcomeConstraint.
558+
559+
Returns:
560+
Tuple of (combined_mean, combined_sem) as numpy arrays.
561+
If any component metric is missing, mean is NaN and sem is 0.
562+
"""
563+
n_rows = len(df)
564+
combined_mean = np.zeros(n_rows)
565+
combined_var = np.zeros(n_rows)
566+
all_metrics_present = True
567+
568+
for metric, weight in constraint.metric_weights:
569+
mean_col = f"{metric.name}_mean"
570+
sem_col = f"{metric.name}_sem"
571+
572+
if mean_col in df.columns:
573+
combined_mean += weight * df[mean_col].values
574+
else:
575+
all_metrics_present = False
576+
break
577+
578+
if sem_col in df.columns:
579+
metric_sem = df[sem_col].fillna(0).values
580+
else:
581+
metric_sem = np.zeros(n_rows)
582+
583+
combined_var += (weight**2) * (metric_sem**2)
584+
585+
if not all_metrics_present:
586+
# Match existing pattern: mean=NaN, sem=0 for missing data
587+
return np.full(n_rows, np.nan), np.zeros(n_rows)
588+
589+
return combined_mean, np.sqrt(combined_var)
590+
591+
543592
def _prepare_p_feasible(
544593
df: pd.DataFrame,
545594
status_quo_df: pd.DataFrame | None,
@@ -571,34 +620,27 @@ def _prepare_p_feasible(
571620
return pd.Series(np.ones(len(df)))
572621

573622
# If an arm is missing data for a metric leave the mean as NaN.
574-
oc_names = []
575-
for oc in outcome_constraints:
576-
if isinstance(oc, ScalarizedOutcomeConstraint):
577-
# take the str representation of the scalarized outcome constraint
578-
oc_names.append(str(oc))
579-
else:
580-
oc_names.append(oc.metric.name)
581-
582-
assert len(oc_names) == len(outcome_constraints)
583-
584623
means = []
585624
sigmas = []
586-
for i, oc_name in enumerate(oc_names):
587-
df_constraint = none_throws(rel_df if outcome_constraints[i].relative else df)
588-
# TODO[T235432214]: currently we are leaving the mean as NaN if the constraint
589-
# is on ScalarizedOutcomeConstraint but we should be able to calculate it by
590-
# setting the mean to be weights * individual metrics and sem to be
591-
# sqrt(sum((weights * individual_sems)^2)), assuming independence.
592-
if f"{oc_name}_mean" in df_constraint.columns:
593-
means.append(df_constraint[f"{oc_name}_mean"].tolist())
625+
for oc in outcome_constraints:
626+
df_constraint = none_throws(rel_df if oc.relative else df)
594627

628+
if isinstance(oc, ScalarizedOutcomeConstraint):
629+
mean, sem = _get_scalarized_constraint_mean_and_sem(df_constraint, oc)
630+
means.append(mean.tolist())
631+
sigmas.append(sem.tolist())
595632
else:
596-
means.append([float("nan")] * len(df_constraint))
597-
sigmas.append(
598-
(df_constraint[f"{oc_name}_sem"].fillna(0)).tolist()
599-
if f"{oc_name}_sem" in df_constraint.columns
600-
else [0] * len(df)
601-
)
633+
metric_name = oc.metric.name
634+
if f"{metric_name}_mean" in df_constraint.columns:
635+
means.append(df_constraint[f"{metric_name}_mean"].tolist())
636+
else:
637+
means.append([float("nan")] * len(df_constraint))
638+
639+
sigmas.append(
640+
(df_constraint[f"{metric_name}_sem"].fillna(0)).tolist()
641+
if f"{metric_name}_sem" in df_constraint.columns
642+
else [0] * len(df)
643+
)
602644

603645
con_lower_inds = [
604646
i
@@ -665,28 +707,27 @@ def _prepare_p_feasible_per_constraint(
665707
if len(outcome_constraints) == 0:
666708
return pd.DataFrame(index=df.index)
667709

668-
oc_names = []
669-
for oc in outcome_constraints:
670-
if isinstance(oc, ScalarizedOutcomeConstraint):
671-
oc_names.append(str(oc))
672-
else:
673-
oc_names.append(oc.metric.name)
674-
675710
result_df = pd.DataFrame(index=df.index)
676711
# Compute probability for each constraint individually
677-
for oc_name, oc in zip(oc_names, outcome_constraints):
712+
for oc in outcome_constraints:
678713
df_constraint = none_throws(rel_df if oc.relative else df)
679714

680-
# Get mean and sigma for this constraint
681-
if f"{oc_name}_mean" in df_constraint.columns:
682-
mean = df_constraint[f"{oc_name}_mean"].values
715+
if isinstance(oc, ScalarizedOutcomeConstraint):
716+
mean, sigma = _get_scalarized_constraint_mean_and_sem(df_constraint, oc)
717+
oc_display_name = str(oc)
683718
else:
684-
mean = np.nan * np.ones(len(df_constraint))
719+
metric_name = oc.metric.name
720+
oc_display_name = metric_name
685721

686-
if f"{oc_name}_sem" in df_constraint.columns:
687-
sigma = df_constraint[f"{oc_name}_sem"].fillna(0).values
688-
else:
689-
sigma = np.zeros(len(df))
722+
if f"{metric_name}_mean" in df_constraint.columns:
723+
mean = df_constraint[f"{metric_name}_mean"].values
724+
else:
725+
mean = np.full(len(df_constraint), np.nan)
726+
727+
if f"{metric_name}_sem" in df_constraint.columns:
728+
sigma = df_constraint[f"{metric_name}_sem"].fillna(0).values
729+
else:
730+
sigma = np.zeros(len(df))
690731

691732
# Convert to torch tensors (shape: [n_arms, 1])
692733
mean_tensor = torch.tensor(mean, dtype=torch.double).unsqueeze(-1)
@@ -706,7 +747,7 @@ def _prepare_p_feasible_per_constraint(
706747

707748
# Convert back to numpy and store in result dataframe
708749
prob = log_prob.exp().squeeze().numpy()
709-
result_df[f"p_feasible_{oc_name}"] = prob
750+
result_df[f"p_feasible_{oc_display_name}"] = prob
710751

711752
return result_df
712753

0 commit comments

Comments
 (0)