Skip to content

Commit 3654123

Browse files
bletham authored and facebook-github-bot committed
ability to include multiple generator runs in arm effects analysis (#4963)
Summary: Pull Request resolved: #4963 Reviewed By: mpolson64 Differential Revision: D94707551
1 parent 7437e76 commit 3654123

7 files changed

Lines changed: 275 additions & 67 deletions

File tree

ax/analysis/plotly/arm_effects.py

Lines changed: 140 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ax.analysis.plotly.color_constants import BOTORCH_COLOR_SCALE
1515
from ax.analysis.plotly.plotly_analysis import create_plotly_analysis_card
1616
from ax.analysis.plotly.utils import (
17+
generator_run_key_to_color,
1718
get_arm_tooltip,
1819
get_trial_statuses_with_fallback,
1920
get_trial_trace_name,
@@ -42,6 +43,7 @@
4243
from ax.core.arm import Arm
4344
from ax.core.data import sort_by_trial_index_and_arm_name
4445
from ax.core.experiment import Experiment
46+
from ax.core.generator_run import GeneratorRun
4547
from ax.core.trial_status import TrialStatus
4648
from ax.generation_strategy.generation_strategy import GenerationStrategy
4749
from plotly import graph_objects as go
@@ -80,6 +82,7 @@ def __init__(
8082
trial_index: int | None = None,
8183
trial_statuses: Sequence[TrialStatus] | None = None,
8284
additional_arms: Sequence[Arm] | None = None,
85+
generator_runs: Mapping[str, GeneratorRun] | None = None,
8386
label: str | None = None,
8487
) -> None:
8588
"""
@@ -99,6 +102,10 @@ def __init__(
99102
additional_arms: If present, include these arms in the plot in addition to
100103
the arms in the experiment. These arms will be marked as belonging to a
101104
trial with index -1.
105+
generator_runs: If present, a mapping from name to GeneratorRun. Each
106+
GeneratorRun's arms will be plotted as a separate group with distinct
107+
colors and legend entries. Unnamed arms will be labeled as
108+
``{key}_0``, ``{key}_1``, etc.
102109
label: A label to use in the plot in place of the metric name.
103110
"""
104111

@@ -112,6 +119,7 @@ def __init__(
112119
)
113120
)
114121
self.additional_arms = additional_arms
122+
self.generator_runs = generator_runs
115123
self.label = label
116124

117125
@override
@@ -190,6 +198,7 @@ def compute(
190198
trial_index=self.trial_index,
191199
trial_statuses=self.trial_statuses,
192200
additional_arms=self.additional_arms,
201+
generator_runs=self.generator_runs,
193202
relativize=self.relativize,
194203
)
195204

@@ -259,6 +268,7 @@ def compute_arm_effects_adhoc(
259268
trial_index: int | None = None,
260269
trial_statuses: Sequence[TrialStatus] | None = None,
261270
additional_arms: Sequence[Arm] | None = None,
271+
generator_runs: Mapping[str, GeneratorRun] | None = None,
262272
labels: Mapping[str, str] | None = None,
263273
) -> AnalysisCardGroup:
264274
"""
@@ -303,6 +313,7 @@ def compute_arm_effects_adhoc(
303313
trial_index=trial_index,
304314
trial_statuses=trial_statuses,
305315
additional_arms=additional_arms,
316+
generator_runs=generator_runs,
306317
label=labels.get(metric_name) if labels is not None else None,
307318
).compute_or_error_card(
308319
experiment=experiment,
@@ -318,6 +329,58 @@ def compute_arm_effects_adhoc(
318329
)
319330

320331

332+
def _build_scatter(
    trial_df: pd.DataFrame,
    metric_name: str,
    is_relative: bool,
    status_quo_arm_name: str | None,
    color: str,
    ci_color: str,
    trace_name: str,
    showlegend: bool,
    legendgroup: str | None,
) -> tuple[go.Scatter, list[str], list[str]] | None:
    """Assemble the marker trace for one group of arms.

    Rows whose ``{metric_name}_mean`` is NaN are dropped. When the plot is
    relativized and a status quo arm name is given, the status quo arm is
    dropped as well.

    Args:
        trial_df: Rows for the group being plotted. Must contain the columns
            ``{metric_name}_mean``, ``{metric_name}_sem``, ``arm_name`` and
            ``x_key_order``.
        metric_name: Metric whose mean/sem columns are plotted.
        is_relative: Whether the values are relativized.
        status_quo_arm_name: Name of the status quo arm, if any.
        color: Marker color.
        ci_color: Color of the error bars.
        trace_name: Name given to the trace.
        showlegend: Whether the trace gets its own legend entry.
        legendgroup: Legend group of the trace, if any.

    Returns:
        ``(scatter, x_key_order values, arm names)`` for the rows that are
        plotted, or ``None`` when no row survives the filtering.
    """
    mean_col = f"{metric_name}_mean"
    sem_col = f"{metric_name}_sem"

    plotted = trial_df[trial_df[mean_col].notna()]
    if plotted.empty:
        return None

    drop_status_quo = is_relative and status_quo_arm_name is not None
    if drop_status_quo:
        plotted = plotted[plotted["arm_name"] != status_quo_arm_name]
        if plotted.empty:
            return None

    # The decision to draw error bars is made on the unfiltered input frame,
    # while the bar lengths come from the rows that are actually plotted.
    error_y = None
    if not trial_df[sem_col].isna().all():
        error_y = {
            "type": "data",
            "array": Z_SCORE_95_CI * plotted[sem_col],
            "color": ci_color,
        }

    def _tooltip(row: pd.Series) -> str:
        return get_arm_tooltip(row=row, metric_names=[metric_name])

    hover_text = plotted.apply(_tooltip, axis=1)
    x_positions = plotted["x_key_order"]

    trace = go.Scatter(
        x=x_positions,
        y=plotted[mean_col],
        error_y=error_y,
        mode="markers",
        marker={"color": color},
        name=trace_name,
        showlegend=showlegend,
        hoverinfo="text",
        text=hover_text,
        legendgroup=legendgroup,
    )
    return trace, x_positions.to_list(), plotted["arm_name"].to_list()
382+
383+
321384
def _prepare_figure(
322385
df: pd.DataFrame,
323386
metric_name: str,
@@ -354,76 +417,94 @@ def _prepare_figure(
354417
num_non_candidate_trials = 0
355418
candidate_trial_marker = None
356419

420+
# --- Trial loop ---
357421
for trial_index in trial_indices:
358-
trial_df = df[df["trial_index"] == trial_index]
359-
xy_df = trial_df[~trial_df[f"{metric_name}_mean"].isna()]
360-
# Skip trials with no valid data points as they will not end up in the plot
361-
if xy_df.empty:
422+
trial_df = df[
423+
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
424+
]
425+
if trial_df.empty:
362426
continue
363-
if is_relative and status_quo_arm_name is not None:
364-
# Exclude status quo arms from relativized plots, since arms are relative
365-
# with respect to the status quo.
366-
xy_df = xy_df[xy_df["arm_name"] != status_quo_arm_name]
367-
368-
arm_order = arm_order + xy_df["x_key_order"].to_list()
369-
arm_label = arm_label + xy_df["arm_name"].to_list()
370-
if not trial_df[f"{metric_name}_sem"].isna().all():
371-
error_y = {
372-
"type": "data",
373-
"array": Z_SCORE_95_CI * xy_df[f"{metric_name}_sem"],
374-
"color": trial_index_to_color(
375-
trial_df=trial_df,
376-
trials_list=trials_list,
377-
trial_index=trial_index,
378-
transparent=True,
379-
),
380-
}
381-
else:
382-
error_y = None
383-
384-
marker = {
385-
"color": trial_index_to_color(
386-
trial_df=trial_df,
387-
trials_list=trials_list,
388-
trial_index=trial_index,
389-
transparent=False,
390-
),
391-
}
392-
393-
if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name:
427+
color = trial_index_to_color(
428+
trial_df=trial_df,
429+
trials_list=trials_list,
430+
trial_index=trial_index,
431+
transparent=False,
432+
)
433+
ci_color = trial_index_to_color(
434+
trial_df=trial_df,
435+
trials_list=trials_list,
436+
trial_index=trial_index,
437+
transparent=True,
438+
)
439+
is_candidate = (
440+
not trial_df.empty
441+
and trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
442+
)
443+
result = _build_scatter(
444+
trial_df=trial_df,
445+
metric_name=metric_name,
446+
is_relative=is_relative,
447+
status_quo_arm_name=status_quo_arm_name,
448+
color=color,
449+
ci_color=ci_color,
450+
trace_name=get_trial_trace_name(trial_index=trial_index),
451+
showlegend=False, # Will be set after determining use_colorscale
452+
legendgroup="candidate_trials" if is_candidate else None,
453+
)
454+
if result is None:
455+
continue
456+
scatter, order, labels = result
457+
scatters.append(scatter)
458+
scatter_trial_indices.append(trial_index)
459+
arm_order += order
460+
arm_label += labels
461+
if is_candidate:
394462
num_candidate_trials += 1
395-
candidate_trial_marker = marker
463+
candidate_trial_marker = {"color": color}
396464
else:
397465
num_non_candidate_trials += 1
398466

399-
text = xy_df.apply(
400-
lambda row: get_arm_tooltip(row=row, metric_names=[metric_name]), axis=1
467+
# --- Generator run loop ---
468+
unique_gr_keys = df["generator_run_key"].dropna().unique().tolist()
469+
generator_run_scatters: list[go.Scatter] = []
470+
for gr_key in unique_gr_keys:
471+
gr_df = df[df["generator_run_key"] == gr_key]
472+
color = generator_run_key_to_color(
473+
generator_run_key=gr_key,
474+
all_generator_run_keys=unique_gr_keys,
475+
transparent=False,
401476
)
402-
403-
scatters.append(
404-
go.Scatter(
405-
x=xy_df["x_key_order"],
406-
y=xy_df[f"{metric_name}_mean"],
407-
error_y=error_y,
408-
mode="markers",
409-
marker=marker,
410-
name=get_trial_trace_name(trial_index=trial_index),
411-
showlegend=False, # Will be set after determining use_colorscale
412-
hoverinfo="text",
413-
text=text,
414-
legendgroup="candidate_trials"
415-
if trial_df["trial_status"].iloc[0] == TrialStatus.CANDIDATE.name
416-
else None,
417-
)
477+
ci_color = generator_run_key_to_color(
478+
generator_run_key=gr_key,
479+
all_generator_run_keys=unique_gr_keys,
480+
transparent=True,
418481
)
419-
scatter_trial_indices.append(trial_index)
482+
result = _build_scatter(
483+
trial_df=gr_df,
484+
metric_name=metric_name,
485+
is_relative=is_relative,
486+
status_quo_arm_name=status_quo_arm_name,
487+
color=color,
488+
ci_color=ci_color,
489+
trace_name=gr_key,
490+
showlegend=True,
491+
legendgroup=None,
492+
)
493+
if result is None:
494+
continue
495+
scatter, order, labels = result
496+
generator_run_scatters.append(scatter)
497+
arm_order += order
498+
arm_label += labels
420499

421500
# Determine use_colorscale based on actual included trials
422501
use_colorscale = num_non_candidate_trials > 10
423502

424503
# Update markers and legend settings based on use_colorscale
425504
for scatter, trial_index in zip(scatters, scatter_trial_indices):
426-
trial_df = df[df["trial_index"] == trial_index]
505+
trial_df = df[
506+
(df["trial_index"] == trial_index) & (df["generator_run_key"].isna())
507+
]
427508

428509
if use_colorscale:
429510
# Add colorscale settings to marker
@@ -449,6 +530,9 @@ def _prepare_figure(
449530
trial_df["trial_status"].iloc[0] != TrialStatus.CANDIDATE.name
450531
)
451532

533+
# Append generator run scatters (not subject to colorscale)
534+
scatters.extend(generator_run_scatters)
535+
452536
# get the max length of x-ticker (arm name) to set the xaxis label and
453537
# legend position
454538
# This assumes the x-tickers are rotated 90 degrees (vertical) so legend

ax/analysis/plotly/color_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@
4242
COLOR_FOR_DECREASES: str = METRIC_CONTINUOUS_COLOR_SCALE[2] # brown
4343

4444
DISCRETE_ARM_SCALE = px.colors.qualitative.Alphabet
45+
GENERATOR_RUN_COLOR_SCALE: list[str] = px.colors.qualitative.Plotly

ax/analysis/plotly/tests/test_arm_effects.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from ax.analysis.plotly.arm_effects import ArmEffectsPlot, compute_arm_effects_adhoc
1313
from ax.api.client import Client
1414
from ax.api.configs import RangeParameterConfig
15+
from ax.core.analysis_card import AnalysisCard
1516
from ax.core.arm import Arm
17+
from ax.core.generator_run import GeneratorRun
1618
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
1719
from ax.exceptions.core import UserInputError
1820
from ax.utils.common.testutils import TestCase
@@ -201,6 +203,57 @@ def test_compute_adhoc(self) -> None:
201203

202204
self.assertEqual(cards, adhoc_cards.children[0])
203205

206+
def test_compute_with_generator_runs(self) -> None:
    """Unnamed arms of a generator run show up on the x-axis as ``{key}_{i}``."""
    unnamed_arms = [
        Arm(parameters={"x1": 0.1, "x2": 0.2}),
        Arm(parameters={"x1": 0.3, "x2": 0.4}),
    ]
    plot = ArmEffectsPlot(
        metric_name="foo",
        use_model_predictions=True,
        generator_runs={"my_gr": GeneratorRun(arms=unnamed_arms)},
    )
    card = plot.compute(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
    )

    # The generator run's arms must appear among the x-axis tick labels
    # under the names derived from the mapping key.
    figure = json.loads(card.blob)
    tick_labels = figure["layout"]["xaxis"]["ticktext"]
    for expected_label in ("my_gr_0", "my_gr_1"):
        self.assertIn(expected_label, tick_labels)
226+
227+
def test_compute_with_additional_arms_and_generator_runs(self) -> None:
    """``additional_arms`` and ``generator_runs`` can be combined in one plot."""
    extra = Arm(parameters={"x1": 0.5, "x2": 0.5}, name="extra_arm")
    generator_run = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})])

    plot = ArmEffectsPlot(
        metric_name="foo",
        use_model_predictions=True,
        additional_arms=[extra],
        generator_runs={"my_gr": generator_run},
    )
    card = plot.compute(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
    )

    # Both the explicitly named extra arm and the auto-labeled generator run
    # arm must be present among the x-axis tick labels.
    figure = json.loads(card.blob)
    tick_labels = figure["layout"]["xaxis"]["ticktext"]
    for expected_label in ("extra_arm", "my_gr_0"):
        self.assertIn(expected_label, tick_labels)
243+
244+
def test_compute_adhoc_with_generator_runs(self) -> None:
    """The ad hoc entry point passes ``generator_runs`` through to the plot."""
    generator_run = GeneratorRun(arms=[Arm(parameters={"x1": 0.1, "x2": 0.2})])

    card_group = compute_arm_effects_adhoc(
        experiment=self.client._experiment,
        generation_strategy=self.client._generation_strategy,
        metric_names=["foo"],
        generator_runs={"my_gr": generator_run},
    )

    # One metric was requested, so exactly one child card is expected.
    children = card_group.children
    self.assertEqual(len(children), 1)

    only_card = assert_is_instance(children[0], AnalysisCard)
    figure = json.loads(only_card.blob)
    self.assertIn("my_gr_0", figure["layout"]["xaxis"]["ticktext"])
256+
204257
@TestCase.ax_long_test(
205258
reason=(
206259
"Adapter.predict still too slow under @mock_botorch_optimize for this test"

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_compute_raw(self) -> None:
129129
"trial_status",
130130
"status_reason",
131131
"generation_node",
132+
"generator_run_key",
132133
"p_feasible_mean",
133134
"p_feasible_sem",
134135
"foo_mean",
@@ -188,6 +189,7 @@ def test_compute_with_modeled(self) -> None:
188189
"trial_status",
189190
"status_reason",
190191
"generation_node",
192+
"generator_run_key",
191193
"p_feasible_mean",
192194
"p_feasible_sem",
193195
"foo_mean",

0 commit comments

Comments
 (0)