From 1773a2a51f0b3ae31bcaac74eef54007b9b01686 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Mon, 23 Mar 2026 06:47:50 -0700 Subject: [PATCH 1/2] add metric prediction summary to cross validation analysis Differential Revision: D94553707 --- ax/analysis/plotly/cross_validation.py | 48 +++++++++++++++++++ .../plotly/tests/test_cross_validation.py | 23 +++++++-- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 23abfcabd49..9c63e9423a2 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -13,6 +13,7 @@ from ax.adapter.base import Adapter from ax.adapter.cross_validation import cross_validate, CVResult from ax.analysis.analysis import Analysis +from ax.analysis.healthcheck.predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD from ax.analysis.plotly.color_constants import AX_BLUE from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card from ax.analysis.plotly.utils import get_scatter_point_color, Z_SCORE_95_CI @@ -106,6 +107,7 @@ def __init__( self.untransform = untransform self.trial_index = trial_index self.labels: dict[str, str] = {**labels} if labels is not None else {} + self._r2s: dict[str, float] = {} @override def validate_applicable_state( @@ -144,6 +146,7 @@ def compute( relevant_adapter._experiment.signature_to_metric[signature].name for signature in relevant_adapter._metric_signatures ] + self._r2s = {} for metric_name in self.metric_names or relevant_adapter_metric_names: df = _prepare_data( metric_name=metric_name, cv_results=cv_results, adapter=relevant_adapter @@ -162,6 +165,7 @@ def compute( y_obs=df["observed"].to_numpy(), y_pred=df["predicted"].to_numpy(), ) + self._r2s[metric_title] = r_squared # Define the cross-validation description based on the number of folds cv_description = ( @@ -202,6 +206,50 @@ def compute( cards.append(card) + # Create a summary table of R2 values for all metrics + if self._r2s: + threshold = DEFAULT_MODEL_FIT_THRESHOLD + metric_names_list = list(self._r2s.keys()) + r2_values = [f"{v:.2f}" for v in self._r2s.values()] + fill_colors = [ + "rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white" + for r2 in self._r2s.values() + ] + r2_fig = go.Figure( + data=[ + go.Table( + columnwidth=[4, 1], + header={ + "values": ["Metric", "R\u00b2"], + "align": "left", + }, + cells={ + "values": [metric_names_list, r2_values], + "align": "left", + "fill_color": [fill_colors, fill_colors], + }, + ) + ] + ) + r2_card = create_plotly_analysis_card( + name=self.__class__.__name__, + title="Summary of model fits", + subtitle=( + "R\u00b2 (coefficient of determination) measures how well" + " the model predicts each metric. Higher values indicate" + " better model fit. Metrics with R\u00b2 >=" + f" {threshold} are highlighted in green." + ), + df=pd.DataFrame( + { + "Metric": metric_names_list, + "R\u00b2": list(self._r2s.values()), + } + ), + fig=r2_fig, + ) + cards.append(r2_card) + return self._create_analysis_card_group( title=CV_CARDGROUP_TITLE, subtitle=CV_CARDGROUP_SUBTITLE, diff --git a/ax/analysis/plotly/tests/test_cross_validation.py b/ax/analysis/plotly/tests/test_cross_validation.py index b7004f819e9..f186b469167 100644 --- a/ax/analysis/plotly/tests/test_cross_validation.py +++ b/ax/analysis/plotly/tests/test_cross_validation.py @@ -65,9 +65,12 @@ def test_compute(self, mock_r2: mock.Mock) -> None: ): analysis.compute() - (card,) = analysis.compute( + cards = analysis.compute( generation_strategy=self.client.generation_strategy ).flatten() + # Should have the CV plot card and the R2 summary card + self.assertEqual(len(cards), 2) + card = cards[0] self.assertEqual( card.name, "CrossValidationPlot", @@ -106,6 +109,15 @@ def test_compute(self, mock_r2: mock.Mock) -> None: ) self.assertIsNotNone(card.blob) + # Assert that _r2s is populated after compute + self.assertIn("bar", analysis._r2s) + self.assertAlmostEqual(analysis._r2s["bar"], 0.85) + + # Assert the R2 summary card + r2_card = cards[1] + self.assertEqual(r2_card.name, "CrossValidationPlot") + self.assertEqual(r2_card.title, "Summary of model fits") + # Assert that all arms are in the cross validation df # because trial index is not specified for t in self.client.experiment.trials.values(): @@ -121,9 +133,10 @@ def test_compute(self, mock_r2: mock.Mock) -> None: def test_it_can_specify_trial_index_correctly(self) -> None: analysis = CrossValidationPlot(metric_names=["bar"], trial_index=9) - (card,) = analysis.compute( + cards = analysis.compute( generation_strategy=self.client.generation_strategy ).flatten() + card = cards[0] for t in self.client.experiment.trials.values(): # Skip the last trial because the model was used to generate it # and therefore hasn't observed it @@ -159,15 +172,17 @@ def test_compute_adhoc(self, mock_r2: mock.Mock) -> None: cards = compute_cross_validation_adhoc( adapter=adapter, labels=metric_mapping ).flatten() - self.assertEqual(len(cards), 2) + self.assertEqual(len(cards), 3) titles = { "Cross Validation for spunky (R\u00b2 = 0.85)", "Cross Validation for foo2 (R\u00b2 = 0.85)", } - for card in cards: + for card in cards[:2]: self.assertEqual(card.name, "CrossValidationPlot") self.assertIn(card.title, titles) titles.remove(card.title) + # The last card is the R2 summary + self.assertEqual(cards[2].title, "Summary of model fits") @TestCase.ax_long_test( reason=( From 2cfa19ed753935b6633232633fdc8ffb0986d5ec Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Mon, 23 Mar 2026 07:44:30 -0700 Subject: [PATCH 2/2] ability to include multiple generator runs in arm effects analysis (#4963) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4963 Reviewed By: mpolson64 Differential Revision: D94707551 --- ax/analysis/plotly/arm_effects.py | 196 +++++++++++++------ ax/analysis/plotly/color_constants.py | 1 + ax/analysis/plotly/tests/test_arm_effects.py | 53 +++++ ax/analysis/plotly/tests/test_scatter.py | 2 + ax/analysis/plotly/utils.py | 30 ++- ax/analysis/tests/test_utils.py | 8 + ax/analysis/utils.py | 52 ++++- 7 files changed, 275 insertions(+), 67 deletions(-) diff --git a/ax/analysis/plotly/arm_effects.py b/ax/analysis/plotly/arm_effects.py index 0dda85e18da..7f94d66d901 100644 --- a/ax/analysis/plotly/arm_effects.py +++ b/ax/analysis/plotly/arm_effects.py @@ -14,6 +14,7 @@ from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card from ax.analysis.plotly.utils import ( + generator_run_key_to_color, get_arm_tooltip, get_trial_statuses_with_fallback, get_trial_trace_name, @@ -42,6 +43,7 @@ from ax.core.arm import Arm from ax.core.data import sort_by_trial_index_and_arm_name from ax.core.experiment import Experiment +from ax.core.generator_run import GeneratorRun from ax.core.trial_status import TrialStatus from ax.generation_strategy.generation_strategy import GenerationStrategy from plotly import graph_objects as go @@ -80,6 +82,7 @@ def __init__( trial_index: int | None = None, trial_statuses: Sequence[TrialStatus] | None = None, additional_arms: Sequence[Arm] | None = None, + generator_runs: Mapping[str, GeneratorRun] | None = None, label: str | None = None, ) -> None: """ @@ -99,6 +102,10 @@ def __init__( additional_arms: If present, include these arms in the plot in addition to the arms in the experiment. These arms will be marked as belonging to a trial with index -1. + generator_runs: If present, a mapping from name to GeneratorRun. Each + GeneratorRun's arms will be plotted as a separate group with distinct + colors and legend entries. Unnamed arms will be labeled as + ``{key}_0``, ``{key}_1``, etc. label: A label to use in the plot in place of the metric name. """ @@ -112,6 +119,7 @@ def __init__( ) ) self.additional_arms = additional_arms + self.generator_runs = generator_runs self.label = label @override @@ -190,6 +198,7 @@ def compute( trial_index=self.trial_index, trial_statuses=self.trial_statuses, additional_arms=self.additional_arms, + generator_runs=self.generator_runs, relativize=self.relativize, ) @@ -259,6 +268,7 @@ def compute_arm_effects_adhoc( trial_index: int | None = None, trial_statuses: Sequence[TrialStatus] | None = None, additional_arms: Sequence[Arm] | None = None, + generator_runs: Mapping[str, GeneratorRun] | None = None, labels: Mapping[str, str] | None = None, ) -> AnalysisCardGroup: """ @@ -303,6 +313,7 @@ def compute_arm_effects_adhoc( trial_index=trial_index, trial_statuses=trial_statuses, additional_arms=additional_arms, + generator_runs=generator_runs, label=labels.get(metric_name) if labels is not None else None, ).compute_or_error_card( experiment=experiment, @@ -318,6 +329,58 @@ def compute_arm_effects_adhoc( ) +def _build_scatter( + trial_df: pd.DataFrame, + metric_name: str, + is_relative: bool, + status_quo_arm_name: str | None, + color: str, + ci_color: str, + trace_name: str, + showlegend: bool, + legendgroup: str | None, +) -> tuple[go.Scatter, list[str], list[str]] | None: + """Build a scatter trace for a group of arms. + + Returns (scatter, arm_order_entries, arm_label_entries), or None if no + valid data points exist. + """ + xy_df = trial_df[~trial_df[f"{metric_name}_mean"].isna()] + if xy_df.empty: + return None + if is_relative and status_quo_arm_name is not None: + xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name] + if xy_df.empty: + return None + + if not trial_df[f"{metric_name}_sem"].isna().all(): + error_y = { + "type": "data", + "array": Z_SCORE_95_CI * xy_df[f"{metric_name}_sem"], + "color": ci_color, + } + else: + error_y = None + + text = xy_df.apply( + lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1 + ) + + scatter = go.Scatter( + x=xy_df["x_key_order"], + y=xy_df[f"{metric_name}_mean"], + error_y=error_y, + mode="markers", + marker={"color": color}, + name=trace_name, + showlegend=showlegend, + hoverinfo="text", + text=text, + legendgroup=legendgroup, + ) + return scatter, xy_df["x_key_order"].to_list(), xy_df["arm_name"].to_list() + + def _prepare_figure( df: pd.DataFrame, metric_name: str, @@ -354,76 +417,94 @@ def _prepare_figure( num_non_candidate_trials = 0 candidate_trial_marker = None + # --- Trial loop --- for trial_index in trial_indices: - trial_df = df[df["trial_index"] == trial_index] - xy_df = trial_df[~trial_df[f"{metric_name}_mean"].isna()] - # Skip trials with no valid data points as they will not end up in the plot - if xy_df.empty: + trial_df = df[ + (df["trial_index"] == trial_index) & (df["generator_run_key"].isna()) + ] + if trial_df.empty: continue - if is_relative and status_quo_arm_name is not None: - # Exclude status quo arms from relativized plots, since arms are relative - # with respect to the status quo. - xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name] - - arm_order = arm_order + xy_df["x_key_order"].to_list() - arm_label = arm_label + xy_df["arm_name"].to_list() - if not trial_df[f"{metric_name}_sem"].isna().all(): - error_y = { - "type": "data", - "array": Z_SCORE_95_CI * xy_df[f"{metric_name}_sem"], - "color": trial_index_to_color( - trial_df=trial_df, - trials_list=trials_list, - trial_index=trial_index, - transparent=True, - ), - } - else: - error_y = None - - marker = { - "color": trial_index_to_color( - trial_df=trial_df, - trials_list=trials_list, - trial_index=trial_index, - transparent=False, - ), - } - - if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name: + color = trial_index_to_color( + trial_df=trial_df, + trials_list=trials_list, + trial_index=trial_index, + transparent=False, + ) + ci_color = trial_index_to_color( + trial_df=trial_df, + trials_list=trials_list, + trial_index=trial_index, + transparent=True, + ) + is_candidate = ( + not trial_df.empty + and trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name + ) + result = _build_scatter( + trial_df=trial_df, + metric_name=metric_name, + is_relative=is_relative, + status_quo_arm_name=status_quo_arm_name, + color=color, + ci_color=ci_color, + trace_name=get_trial_trace_name(trial_index=trial_index), + showlegend=False, # Will be set after determining use_colorscale + legendgroup="candidate_trials" if is_candidate else None, + ) + if result is None: + continue + scatter, order, labels = result + scatters.append(scatter) + scatter_trial_indices.append(trial_index) + arm_order += order + arm_label += labels + if is_candidate: num_candidate_trials += 1 - candidate_trial_marker = marker + candidate_trial_marker = {"color": color} else: num_non_candidate_trials += 1 - text = xy_df.apply( - lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1 + # --- Generator run loop --- + unique_gr_keys = df["generator_run_key"].dropna().unique().tolist() + generator_run_scatters: list[go.Scatter] = [] + for gr_key in unique_gr_keys: + gr_df = df[df["generator_run_key"] == gr_key] + color = generator_run_key_to_color( + generator_run_key=gr_key, + all_generator_run_keys=unique_gr_keys, + transparent=False, ) - - scatters.append( - go.Scatter( - x=xy_df["x_key_order"], - y=xy_df[f"{metric_name}_mean"], - error_y=error_y, - mode="markers", - marker=marker, - name=get_trial_trace_name(trial_index=trial_index), - showlegend=False, # Will be set after determining use_colorscale - hoverinfo="text", - text=text, - legendgroup="candidate_trials" - if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name - else None, - ) + ci_color = generator_run_key_to_color( + generator_run_key=gr_key, + all_generator_run_keys=unique_gr_keys, + transparent=True, ) - scatter_trial_indices.append(trial_index) + result = _build_scatter( + trial_df=gr_df, + metric_name=metric_name, + is_relative=is_relative, + status_quo_arm_name=status_quo_arm_name, + color=color, + ci_color=ci_color, + trace_name=gr_key, + showlegend=True, + legendgroup=None, + ) + if result is None: + continue + scatter, order, labels = result + generator_run_scatters.append(scatter) + arm_order += order + arm_label += labels # Determine use_colorscale based on actual included trials use_colorscale = num_non_candidate_trials > 10 # Update markers and legend settings based on use_colorscale for scatter, trial_index in zip(scatters, scatter_trial_indices): - trial_df = df[df["trial_index"] == trial_index] + trial_df = df[ + (df["trial_index"] == trial_index) & (df["generator_run_key"].isna()) + ] if use_colorscale: # Add colorscale settings to marker @@ -449,6 +530,9 @@ def _prepare_figure( trial_df["trial_status"].iloc[0] != TrialStatus.CANDIDATE.name ) + # Append generator run scatters (not subject to colorscale) + scatters.extend(generator_run_scatters) + # get the max length of x-ticker (arm name) to set the xaxis label and # legend position # This assumes the x-tickers are rotated 90 degrees (vertical) so legend diff --git a/ax/analysis/plotly/color_constants.py b/ax/analysis/plotly/color_constants.py index dc1373555c6..71af75a6ef3 100644 --- a/ax/analysis/plotly/color_constants.py +++ b/ax/analysis/plotly/color_constants.py @@ -42,3 +42,4 @@ COLOR_FOR_DECREASES: str = METRIC_CONTINUOUS_COLOR_SCALE[2] # brown DISCRETE_ARM_SCALE = px.colors.qualitative.Alphabet +GENERATOR_RUN_COLOR_SCALE: list[str] = px.colors.qualitative.Plotly diff --git a/ax/analysis/plotly/tests/test_arm_effects.py b/ax/analysis/plotly/tests/test_arm_effects.py index 386a054dfbc..80b7ed3e9a5 100644 --- a/ax/analysis/plotly/tests/test_arm_effects.py +++ b/ax/analysis/plotly/tests/test_arm_effects.py @@ -12,7 +12,9 @@ from ax.analysis.plotly.arm_effects import ArmEffectsPlot, compute_arm_effects_adhoc from ax.api.client import Client from ax.api.configs import RangeParameterConfig +from ax.core.analysis_card import AnalysisCard from ax.core.arm import Arm +from ax.core.generator_run import GeneratorRun from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus from ax.exceptions.core import UserInputError from ax.utils.common.testutils import TestCase @@ -201,6 +203,57 @@ def test_compute_adhoc(self) -> None: self.assertEqual(cards, adhoc_cards.children[0]) + def test_compute_with_generator_runs(self) -> None: + gr = GeneratorRun( + arms=[ + Arm(parameters={"x1": 0.1, "x2": 0.2}), + Arm(parameters={"x1": 0.3, "x2": 0.4}), + ] + ) + analysis = ArmEffectsPlot( + metric_name="foo", + use_model_predictions=True, + generator_runs={"my_gr": gr}, + ) + card = analysis.compute( + experiment=self.client._experiment, + generation_strategy=self.client._generation_strategy, + ) + # Check that generator run arms appear with the expected names + ticktext = json.loads(card.blob)["layout"]["xaxis"]["ticktext"] + self.assertIn("my_gr_0", ticktext) + self.assertIn("my_gr_1", ticktext) + + def test_compute_with_additional_arms_and_generator_runs(self) -> None: + additional_arm = Arm(parameters={"x1": 0.5, "x2": 0.5}, name="extra_arm") + gr = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})]) + analysis = ArmEffectsPlot( + metric_name="foo", + use_model_predictions=True, + additional_arms=[additional_arm], + generator_runs={"my_gr": gr}, + ) + card = analysis.compute( + experiment=self.client._experiment, + generation_strategy=self.client._generation_strategy, + ) + ticktext = json.loads(card.blob)["layout"]["xaxis"]["ticktext"] + self.assertIn("extra_arm", ticktext) + self.assertIn("my_gr_0", ticktext) + + def test_compute_adhoc_with_generator_runs(self) -> None: + gr = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})]) + cards = compute_arm_effects_adhoc( + experiment=self.client._experiment, + generation_strategy=self.client._generation_strategy, + metric_names=["foo"], + generator_runs={"my_gr": gr}, + ) + self.assertEqual(len(cards.children), 1) + card = assert_is_instance(cards.children[0], AnalysisCard) + ticktext = json.loads(card.blob)["layout"]["xaxis"]["ticktext"] + self.assertIn("my_gr_0", ticktext) + @TestCase.ax_long_test( reason=( "Adapter.predict still too slow under @mock_botorch_optimize for this test" diff --git a/ax/analysis/plotly/tests/test_scatter.py b/ax/analysis/plotly/tests/test_scatter.py index 6cec98dbfdd..a2779913ea4 100644 --- a/ax/analysis/plotly/tests/test_scatter.py +++ b/ax/analysis/plotly/tests/test_scatter.py @@ -129,6 +129,7 @@ def test_compute_raw(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -188,6 +189,7 @@ def test_compute_with_modeled(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py index cd7c6c3fc9f..7f98ed9710d 100644 --- a/ax/analysis/plotly/utils.py +++ b/ax/analysis/plotly/utils.py @@ -10,7 +10,11 @@ from collections.abc import Sequence import pandas as pd -from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE, LIGHT_AX_BLUE +from ax.analysis.plotly.color_constants import ( + BOTORCH_COLOR_SCALE, + GENERATOR_RUN_COLOR_SCALE, + LIGHT_AX_BLUE, +) from ax.core.experiment import Experiment from ax.core.objective import ScalarizedObjective from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus @@ -125,6 +129,17 @@ def trial_index_to_color( return get_scatter_point_color(hex_color=hex_color, ci_transparency=transparent) +def generator_run_key_to_color( + generator_run_key: str, + all_generator_run_keys: Sequence[str], + transparent: bool, +) -> str: + """Determines the color for a generator run based on its key.""" + key_index = list(all_generator_run_keys).index(generator_run_key) + hex_color = GENERATOR_RUN_COLOR_SCALE[key_index % len(GENERATOR_RUN_COLOR_SCALE)] + return get_scatter_point_color(hex_color=hex_color, ci_transparency=transparent) + + def get_arm_tooltip( row: pd.Series, metric_names: Sequence[str], @@ -135,8 +150,10 @@ def get_arm_tooltip( """ tooltip_strs = [] trial_index = row["trial_index"] - if trial_index != -1: - # omit the trial tooltip for additional arms + generator_run_key = row.get("generator_run_key") + if pd.notna(generator_run_key): + tooltip_strs.append(f"Generator Run: {generator_run_key}") + elif trial_index != -1: tooltip_strs.append(f"Trial: {trial_index}") tooltip_strs.append(f"Arm: {row['arm_name']}") @@ -164,8 +181,13 @@ def get_arm_tooltip( return "
".join(tooltip_strs) -def get_trial_trace_name(trial_index: int) -> str: +def get_trial_trace_name( + trial_index: int, + generator_run_key: str | None = None, +) -> str: """Get a trace name for a trial index.""" + if generator_run_key is not None: + return generator_run_key return "Additional Arms" if trial_index == -1 else f"Trial {trial_index}" diff --git a/ax/analysis/tests/test_utils.py b/ax/analysis/tests/test_utils.py index 92120c0a0b0..bde31c42f37 100644 --- a/ax/analysis/tests/test_utils.py +++ b/ax/analysis/tests/test_utils.py @@ -156,6 +156,7 @@ def test_prepare_arm_data_raw(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -203,6 +204,7 @@ def test_prepare_arm_data_raw(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -234,6 +236,7 @@ def test_prepare_arm_data_raw(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -343,6 +346,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -387,6 +391,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -420,6 +425,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -459,6 +465,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", @@ -502,6 +509,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None: "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", "foo_mean", diff --git a/ax/analysis/utils.py b/ax/analysis/utils.py index 1e6e50eab22..e506ac21406 100644 --- a/ax/analysis/utils.py +++ b/ax/analysis/utils.py @@ -5,7 +5,7 @@ # pyre-strict -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from logging import Logger import numpy as np @@ -18,6 +18,7 @@ from ax.core.base_trial import BaseTrial from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment +from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint from ax.core.trial_status import TrialStatus @@ -102,6 +103,7 @@ def prepare_arm_data( trial_index: int | None = None, trial_statuses: Sequence[TrialStatus] | None = None, additional_arms: Sequence[Arm] | None = None, + generator_runs: Mapping[str, GeneratorRun] | None = None, relativize: bool = False, compute_p_feasible_per_constraint: bool = False, ) -> pd.DataFrame: @@ -205,6 +207,7 @@ def prepare_arm_data( trial_index=trial_index, trial_statuses=trial_statuses, additional_arms=additional_arms, + generator_runs=generator_runs, target_trial_index=target_trial_index, ) else: @@ -214,6 +217,12 @@ def prepare_arm_data( "there is no observed raw data for the additional arms that are not " "part of the Experiment." ) + if generator_runs is not None: + raise UserInputError( + "Cannot provide generator_runs when use_model_predictions=False since " + "there is no observed raw data for the generator run arms that are not " + "part of the Experiment." + ) df = _prepare_raw_arm_data( metric_names=metric_names, @@ -223,6 +232,8 @@ def prepare_arm_data( target_trial_index=target_trial_index, ) raw_df = df + if "generator_run_key" not in df.columns: + df["generator_run_key"] = None has_relative_constraints = experiment.optimization_config is not None and any( oc.relative for oc in experiment.optimization_config.outcome_constraints ) @@ -251,10 +262,15 @@ def prepare_arm_data( ) # Add additional columns which do not require predicting or extracting data. - df["trial_status"] = df["trial_index"].apply( - lambda trial_index: experiment.trials[trial_index].status.name - if trial_index != -1 - else "Additional Arm" + df["trial_status"] = df.apply( + lambda row: experiment.trials[row["trial_index"]].status.name + if row["trial_index"] != -1 + else ( + row["generator_run_key"] + if pd.notna(row.get("generator_run_key")) + else "Additional Arm" + ), + axis=1, ) df["status_reason"] = df["trial_index"].apply( lambda trial_index: experiment.trials[trial_index].status_reason @@ -314,6 +330,7 @@ def prepare_arm_data( "trial_status", "status_reason", "generation_node", + "generator_run_key", "p_feasible_mean", "p_feasible_sem", } @@ -334,6 +351,7 @@ def _prepare_modeled_arm_data( trial_index: int | None = None, trial_statuses: Sequence[TrialStatus] | None = None, additional_arms: Sequence[Arm] | None = None, + generator_runs: Mapping[str, GeneratorRun] | None = None, target_trial_index: int | None = None, ) -> pd.DataFrame: """ @@ -388,13 +406,29 @@ def _prepare_modeled_arm_data( else trial.arms ) ] + # Track generator_run_key for each pair: None for trial/additional arms + generator_run_keys: list[str | None] = [ + None for _ in range(len(trial_index_arm_pairs)) + ] # Add additional arms passed in by the user - trial_index_arm_pairs += [(-1, arm) for arm in additional_arms or []] + additional_arms_list = additional_arms or [] + trial_index_arm_pairs += [(-1, arm) for arm in additional_arms_list] + generator_run_keys += [None] * len(additional_arms_list) + # Add generator run arms with naming + if generator_runs is not None: + for gr_key, gr in generator_runs.items(): + for j, arm in enumerate(gr.arms): + if not arm.has_name: + arm = Arm(parameters=arm.parameters, name=f"{gr_key}_{j}") + trial_index_arm_pairs.append((-1, arm)) + generator_run_keys.append(gr_key) # Remove arms with missing parameters since we cannot predict for them. predictable_pairs = [] unpredictable_pairs = [] - for trial_index, arm in trial_index_arm_pairs: + predictable_gr_keys: list[str | None] = [] + unpredictable_gr_keys: list[str | None] = [] + for i, (trial_index, arm) in enumerate(trial_index_arm_pairs): if adapter.model_space.check_membership( parameterization=arm.parameters, raise_error=False, @@ -402,8 +436,10 @@ def _prepare_modeled_arm_data( check_range_bounds=False, ): predictable_pairs.append((trial_index, arm)) + predictable_gr_keys.append(generator_run_keys[i]) else: unpredictable_pairs.append((trial_index, arm)) + unpredictable_gr_keys.append(generator_run_keys[i]) # Batch predict for efficiency. predictions = adapter.predict( @@ -428,6 +464,7 @@ def _prepare_modeled_arm_data( "arm_name": predictable_pairs[i][1].name if predictable_pairs[i][1].has_name else f"{Keys.UNNAMED_ARM.value}_{i}", + "generator_run_key": predictable_gr_keys[i], **{ f"{metric_name}_mean": predictions[0][metric_name][i] for metric_name in all_predicted_metrics @@ -446,6 +483,7 @@ def _prepare_modeled_arm_data( "arm_name": unpredictable_pairs[i][1].name if unpredictable_pairs[i][1].has_name else f"{Keys.UNNAMED_ARM.value}_{i}", + "generator_run_key": unpredictable_gr_keys[i], **{ f"{metric_name}_mean": None for metric_name in all_predicted_metrics },