Skip to content

Commit d1f4538

Browse files
bletham authored and meta-codesync[bot] committed
ability to include multiple generator runs in arm effects analysis (#4963)
Summary: Pull Request resolved: #4963 Differential Revision: D94707551
1 parent e32c80c commit d1f4538

File tree

7 files changed

+274
-67
lines changed

7 files changed

+274
-67
lines changed

ax/analysis/plotly/arm_effects.py

Lines changed: 140 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE
1515
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
1616
from ax.analysis.plotly.utils import (
17+
generator_run_key_to_color,
1718
get_arm_tooltip,
1819
get_trial_statuses_with_fallback,
1920
get_trial_trace_name,
@@ -38,6 +39,7 @@
3839
from ax.core.arm import Arm
3940
from ax.core.data import sort_by_trial_index_and_arm_name
4041
from ax.core.experiment import Experiment
42+
from ax.core.generator_run import GeneratorRun
4143
from ax.core.trial_status import TrialStatus
4244
from ax.generation_strategy.generation_strategy import GenerationStrategy
4345
from plotly import graph_objects as go
@@ -76,6 +78,7 @@ def __init__(
7678
trial_index: int | None = None,
7779
trial_statuses: Sequence[TrialStatus] | None = None,
7880
additional_arms: Sequence[Arm] | None = None,
81+
generator_runs: Mapping[str, GeneratorRun] | None = None,
7982
label: str | None = None,
8083
) -> None:
8184
"""
@@ -95,6 +98,10 @@ def __init__(
9598
additional_arms: If present, include these arms in the plot in addition to
9699
the arms in the experiment. These arms will be marked as belonging to a
97100
trial with index -1.
101+
generator_runs: If present, a mapping from name to GeneratorRun. Each
102+
GeneratorRun's arms will be plotted as a separate group with distinct
103+
colors and legend entries. Unnamed arms will be labeled as
104+
``{key}_0``, ``{key}_1``, etc.
98105
label: A label to use in the plot in place of the metric name.
99106
"""
100107

@@ -108,6 +115,7 @@ def __init__(
108115
)
109116
)
110117
self.additional_arms = additional_arms
118+
self.generator_runs = generator_runs
111119
self.label = label
112120

113121
@override
@@ -186,6 +194,7 @@ def compute(
186194
trial_index=self.trial_index,
187195
trial_statuses=self.trial_statuses,
188196
additional_arms=self.additional_arms,
197+
generator_runs=self.generator_runs,
189198
relativize=self.relativize,
190199
)
191200

@@ -255,6 +264,7 @@ def compute_arm_effects_adhoc(
255264
trial_index: int | None = None,
256265
trial_statuses: Sequence[TrialStatus] | None = None,
257266
additional_arms: Sequence[Arm] | None = None,
267+
generator_runs: Mapping[str, GeneratorRun] | None = None,
258268
labels: Mapping[str, str] | None = None,
259269
) -> AnalysisCardGroup:
260270
"""
@@ -299,6 +309,7 @@ def compute_arm_effects_adhoc(
299309
trial_index=trial_index,
300310
trial_statuses=trial_statuses,
301311
additional_arms=additional_arms,
312+
generator_runs=generator_runs,
302313
label=labels.get(metric_name) if labels is not None else None,
303314
).compute_or_error_card(
304315
experiment=experiment,
@@ -314,6 +325,58 @@ def compute_arm_effects_adhoc(
314325
)
315326

316327

328+
def _build_scatter(
    trial_df: pd.DataFrame,
    metric_name: str,
    is_relative: bool,
    status_quo_arm_name: str | None,
    color: str,
    ci_color: str,
    trace_name: str,
    showlegend: bool,
    legendgroup: str | None,
) -> tuple[go.Scatter, list[str], list[str]] | None:
    """Build a markers-only scatter trace for one group of arms.

    Rows without a ``{metric_name}_mean`` value are dropped. When the plot is
    relativized and a status quo arm name is given, the status quo row is
    dropped as well.

    Args:
        trial_df: Rows for the arms in this group. Reads the columns
            ``arm_name``, ``x_key_order``, ``{metric_name}_mean`` and
            ``{metric_name}_sem``; each row is also passed to
            ``get_arm_tooltip`` to build the hover text.
        metric_name: Metric whose ``_mean`` / ``_sem`` columns are plotted.
        is_relative: Whether the plotted values are relativized.
        status_quo_arm_name: Name of the status quo arm, if any.
        color: Marker color.
        ci_color: Color of the error bars.
        trace_name: Name of the trace (used as its legend entry).
        showlegend: Whether the trace is shown in the legend.
        legendgroup: Legend group of the trace, or ``None``.

    Returns:
        ``None`` if no plottable rows remain. Otherwise a tuple of the scatter
        trace, the ``x_key_order`` values of the plotted rows, and the
        ``arm_name`` values of the plotted rows (both in row order).
    """
    mean_col = f"{metric_name}_mean"
    sem_col = f"{metric_name}_sem"

    # Only rows with a mean can be placed on the plot.
    xy_df = trial_df[trial_df[mean_col].notna()]
    if xy_df.empty:
        return None
    if is_relative and status_quo_arm_name is not None:
        # Arms are relative to the status quo, so it is not plotted itself.
        xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name]
        if xy_df.empty:
            return None

    # Decide on error bars from the rows that are actually plotted. Checking
    # the unfiltered frame could produce an error-bar array made up solely of
    # NaNs when the only rows with a SEM were filtered out above.
    if xy_df[sem_col].notna().any():
        error_y = {
            "type": "data",
            "array": Z_SCORE_95_CI * xy_df[sem_col],
            "color": ci_color,
        }
    else:
        error_y = None

    text = xy_df.apply(
        lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1
    )

    scatter = go.Scatter(
        x=xy_df["x_key_order"],
        y=xy_df[mean_col],
        error_y=error_y,
        mode="markers",
        marker={"color": color},
        name=trace_name,
        showlegend=showlegend,
        hoverinfo="text",
        text=text,
        legendgroup=legendgroup,
    )
    return scatter, xy_df["x_key_order"].to_list(), xy_df["arm_name"].to_list()
378+
379+
317380
def _prepare_figure(
318381
df: pd.DataFrame,
319382
metric_name: str,
@@ -350,76 +413,94 @@ def _prepare_figure(
350413
num_non_candidate_trials = 0
351414
candidate_trial_marker = None
352415

416+
# --- Trial loop ---
353417
for trial_index in trial_indices:
354-
trial_df = df[df["trial_index"] == trial_index]
355-
xy_df = trial_df[~trial_df[f"{metric_name}_mean"].isna()]
356-
# Skip trials with no valid data points as they will not end up in the plot
357-
if xy_df.empty:
418+
trial_df = df[
419+
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
420+
]
421+
if trial_df.empty:
358422
continue
359-
if is_relative and status_quo_arm_name is not None:
360-
# Exclude status quo arms from relativized plots, since arms are relative
361-
# with respect to the status quo.
362-
xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name]
363-
364-
arm_order = arm_order + xy_df["x_key_order"].to_list()
365-
arm_label = arm_label + xy_df["arm_name"].to_list()
366-
if not trial_df[f"{metric_name}_sem"].isna().all():
367-
error_y = {
368-
"type": "data",
369-
"array": Z_SCORE_95_CI * xy_df[f"{metric_name}_sem"],
370-
"color": trial_index_to_color(
371-
trial_df=trial_df,
372-
trials_list=trials_list,
373-
trial_index=trial_index,
374-
transparent=True,
375-
),
376-
}
377-
else:
378-
error_y = None
379-
380-
marker = {
381-
"color": trial_index_to_color(
382-
trial_df=trial_df,
383-
trials_list=trials_list,
384-
trial_index=trial_index,
385-
transparent=False,
386-
),
387-
}
388-
389-
if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name:
423+
color = trial_index_to_color(
424+
trial_df=trial_df,
425+
trials_list=trials_list,
426+
trial_index=trial_index,
427+
transparent=False,
428+
)
429+
ci_color = trial_index_to_color(
430+
trial_df=trial_df,
431+
trials_list=trials_list,
432+
trial_index=trial_index,
433+
transparent=True,
434+
)
435+
is_candidate = (
436+
not trial_df.empty
437+
and trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
438+
)
439+
result = _build_scatter(
440+
trial_df=trial_df,
441+
metric_name=metric_name,
442+
is_relative=is_relative,
443+
status_quo_arm_name=status_quo_arm_name,
444+
color=color,
445+
ci_color=ci_color,
446+
trace_name=get_trial_trace_name(trial_index=trial_index),
447+
showlegend=False, # Will be set after determining use_colorscale
448+
legendgroup="candidate_trials" if is_candidate else None,
449+
)
450+
if result is None:
451+
continue
452+
scatter, order, labels = result
453+
scatters.append(scatter)
454+
scatter_trial_indices.append(trial_index)
455+
arm_order += order
456+
arm_label += labels
457+
if is_candidate:
390458
num_candidate_trials += 1
391-
candidate_trial_marker = marker
459+
candidate_trial_marker = {"color": color}
392460
else:
393461
num_non_candidate_trials += 1
394462

395-
text = xy_df.apply(
396-
lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1
463+
# --- Generator run loop ---
464+
unique_gr_keys = df["generator_run_key"].dropna().unique().tolist()
465+
generator_run_scatters: list[go.Scatter] = []
466+
for gr_key in unique_gr_keys:
467+
gr_df = df[df["generator_run_key"] == gr_key]
468+
color = generator_run_key_to_color(
469+
generator_run_key=gr_key,
470+
all_generator_run_keys=unique_gr_keys,
471+
transparent=False,
397472
)
398-
399-
scatters.append(
400-
go.Scatter(
401-
x=xy_df["x_key_order"],
402-
y=xy_df[f"{metric_name}_mean"],
403-
error_y=error_y,
404-
mode="markers",
405-
marker=marker,
406-
name=get_trial_trace_name(trial_index=trial_index),
407-
showlegend=False, # Will be set after determining use_colorscale
408-
hoverinfo="text",
409-
text=text,
410-
legendgroup="candidate_trials"
411-
if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
412-
else None,
413-
)
473+
ci_color = generator_run_key_to_color(
474+
generator_run_key=gr_key,
475+
all_generator_run_keys=unique_gr_keys,
476+
transparent=True,
414477
)
415-
scatter_trial_indices.append(trial_index)
478+
result = _build_scatter(
479+
trial_df=gr_df,
480+
metric_name=metric_name,
481+
is_relative=is_relative,
482+
status_quo_arm_name=status_quo_arm_name,
483+
color=color,
484+
ci_color=ci_color,
485+
trace_name=gr_key,
486+
showlegend=True,
487+
legendgroup=None,
488+
)
489+
if result is None:
490+
continue
491+
scatter, order, labels = result
492+
generator_run_scatters.append(scatter)
493+
arm_order += order
494+
arm_label += labels
416495

417496
# Determine use_colorscale based on actual included trials
418497
use_colorscale = num_non_candidate_trials > 10
419498

420499
# Update markers and legend settings based on use_colorscale
421500
for scatter, trial_index in zip(scatters, scatter_trial_indices):
422-
trial_df = df[df["trial_index"] == trial_index]
501+
trial_df = df[
502+
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
503+
]
423504

424505
if use_colorscale:
425506
# Add colorscale settings to marker
@@ -445,6 +526,9 @@ def _prepare_figure(
445526
trial_df["trial_status"].iloc[0] != TrialStatus.CANDIDATE.name
446527
)
447528

529+
# Append generator run scatters (not subject to colorscale)
530+
scatters.extend(generator_run_scatters)
531+
448532
# get the max length of x-ticker (arm name) to set the xaxis label and
449533
# legend position
450534
# This assumes the x-tickers are rotated 90 degrees (vertical) so legend

ax/analysis/plotly/color_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@
4242
COLOR_FOR_DECREASES: str = METRIC_CONTINUOUS_COLOR_SCALE[2] # brown
4343

4444
DISCRETE_ARM_SCALE = px.colors.qualitative.Alphabet
45+
GENERATOR_RUN_COLOR_SCALE: list[str] = px.colors.qualitative.Plotly

ax/analysis/plotly/tests/test_arm_effects.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from ax.analysis.plotly.arm_effects import ArmEffectsPlot, compute_arm_effects_adhoc
1313
from ax.api.client import Client
1414
from ax.api.configs import RangeParameterConfig
15+
from ax.core.analysis_card import AnalysisCard
1516
from ax.core.arm import Arm
17+
from ax.core.generator_run import GeneratorRun
1618
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1719
from ax.exceptions.core import UserInputError
1820
from ax.utils.common.testutils import TestCase
@@ -201,6 +203,57 @@ def test_compute_adhoc(self) -> None:
201203

202204
self.assertEqual(cards, adhoc_cards.children[0])
203205

206+
def test_compute_with_generator_runs(self) -> None:
    """Unnamed generator-run arms show up on the x-axis as ``{key}_{i}``."""
    unnamed_arms = [
        Arm(parameters={"x1": 0.1, "x2": 0.2}),
        Arm(parameters={"x1": 0.3, "x2": 0.4}),
    ]
    generator_run = GeneratorRun(arms=unnamed_arms)
    plot = ArmEffectsPlot(
        metric_name="foo",
        use_model_predictions=True,
        generator_runs={"my_gr": generator_run},
    )
    card = plot.compute(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
    )
    # The tick labels of the rendered figure carry the arm names.
    figure = json.loads(card.blob)
    ticktext = figure["layout"]["xaxis"]["ticktext"]
    for expected_name in ("my_gr_0", "my_gr_1"):
        self.assertIn(expected_name, ticktext)
226+
227+
def test_compute_with_additional_arms_and_generator_runs(self) -> None:
    """Additional arms and generator-run arms can be plotted together."""
    extra = Arm(parameters={"x1": 0.5, "x2": 0.5}, name="extra_arm")
    generator_run = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})])
    plot = ArmEffectsPlot(
        metric_name="foo",
        use_model_predictions=True,
        additional_arms=[extra],
        generator_runs={"my_gr": generator_run},
    )
    card = plot.compute(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
    )
    # Both the named additional arm and the auto-named generator-run arm
    # must appear among the x-axis tick labels.
    figure = json.loads(card.blob)
    ticktext = figure["layout"]["xaxis"]["ticktext"]
    for expected_name in ("extra_arm", "my_gr_0"):
        self.assertIn(expected_name, ticktext)
243+
244+
def test_compute_adhoc_with_generator_runs(self) -> None:
    """The ad hoc entry point passes ``generator_runs`` through to the plot."""
    generator_run = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})])
    card_group = compute_arm_effects_adhoc(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
        metric_names=["foo"],
        generator_runs={"my_gr": generator_run},
    )
    # One metric was requested, so exactly one card is expected.
    children = card_group.children
    self.assertEqual(len(children), 1)
    only_card = assert_is_instance(children[0], AnalysisCard)
    figure = json.loads(only_card.blob)
    self.assertIn("my_gr_0", figure["layout"]["xaxis"]["ticktext"])
256+
204257
@TestCase.ax_long_test(
205258
reason=(
206259
"Adapter.predict still too slow under @mock_botorch_optimize for this test"

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_compute_raw(self) -> None:
129129
"trial_status",
130130
"status_reason",
131131
"generation_node",
132+
"generator_run_key",
132133
"p_feasible_mean",
133134
"p_feasible_sem",
134135
"foo_mean",
@@ -188,6 +189,7 @@ def test_compute_with_modeled(self) -> None:
188189
"trial_status",
189190
"status_reason",
190191
"generation_node",
192+
"generator_run_key",
191193
"p_feasible_mean",
192194
"p_feasible_sem",
193195
"foo_mean",

0 commit comments

Comments
 (0)