1414from ax .analysis .plotly .color_constants import BOTORCH_COLOR_SCALE
1515from ax .analysis .plotly .plotly_analysis import create_plotly_analysis_card
1616from ax .analysis .plotly .utils import (
17+ generator_run_key_to_color ,
1718 get_arm_tooltip ,
1819 get_trial_statuses_with_fallback ,
1920 get_trial_trace_name ,
4243from ax .core .arm import Arm
4344from ax .core .data import sort_by_trial_index_and_arm_name
4445from ax .core .experiment import Experiment
46+ from ax .core .generator_run import GeneratorRun
4547from ax .core .trial_status import TrialStatus
4648from ax .generation_strategy .generation_strategy import GenerationStrategy
4749from plotly import graph_objects as go
@@ -80,6 +82,7 @@ def __init__(
8082 trial_index : int | None = None ,
8183 trial_statuses : Sequence [TrialStatus ] | None = None ,
8284 additional_arms : Sequence [Arm ] | None = None ,
85+ generator_runs : Mapping [str , GeneratorRun ] | None = None ,
8386 label : str | None = None ,
8487 ) -> None :
8588 """
@@ -99,6 +102,10 @@ def __init__(
99102 additional_arms: If present, include these arms in the plot in addition to
100103 the arms in the experiment. These arms will be marked as belonging to a
101104 trial with index -1.
105+ generator_runs: If present, a mapping from name to GeneratorRun. Each
106+ GeneratorRun's arms will be plotted as a separate group with distinct
107+ colors and legend entries. Unnamed arms will be labeled as
108+ ``{key}_0``, ``{key}_1``, etc.
102109 label: A label to use in the plot in place of the metric name.
103110 """
104111
@@ -112,6 +119,7 @@ def __init__(
112119 )
113120 )
114121 self .additional_arms = additional_arms
122+ self .generator_runs = generator_runs
115123 self .label = label
116124
117125 @override
@@ -190,6 +198,7 @@ def compute(
190198 trial_index = self .trial_index ,
191199 trial_statuses = self .trial_statuses ,
192200 additional_arms = self .additional_arms ,
201+ generator_runs = self .generator_runs ,
193202 relativize = self .relativize ,
194203 )
195204
@@ -259,6 +268,7 @@ def compute_arm_effects_adhoc(
259268 trial_index : int | None = None ,
260269 trial_statuses : Sequence [TrialStatus ] | None = None ,
261270 additional_arms : Sequence [Arm ] | None = None ,
271+ generator_runs : Mapping [str , GeneratorRun ] | None = None ,
262272 labels : Mapping [str , str ] | None = None ,
263273) -> AnalysisCardGroup :
264274 """
@@ -303,6 +313,7 @@ def compute_arm_effects_adhoc(
303313 trial_index = trial_index ,
304314 trial_statuses = trial_statuses ,
305315 additional_arms = additional_arms ,
316+ generator_runs = generator_runs ,
306317 label = labels .get (metric_name ) if labels is not None else None ,
307318 ).compute_or_error_card (
308319 experiment = experiment ,
@@ -318,6 +329,58 @@ def compute_arm_effects_adhoc(
318329 )
319330
320331
332+ def _build_scatter (
333+ trial_df : pd .DataFrame ,
334+ metric_name : str ,
335+ is_relative : bool ,
336+ status_quo_arm_name : str | None ,
337+ color : str ,
338+ ci_color : str ,
339+ trace_name : str ,
340+ showlegend : bool ,
341+ legendgroup : str | None ,
342+ ) -> tuple [go .Scatter , list [str ], list [str ]] | None :
343+ """Build a scatter trace for a group of arms.
344+
345+ Returns (scatter, arm_order_entries, arm_label_entries), or None if no
346+ valid data points exist.
347+ """
348+ xy_df = trial_df [~ trial_df [f"{ metric_name } _mean" ].isna ()]
349+ if xy_df .empty :
350+ return None
351+ if is_relative and status_quo_arm_name is not None :
352+ xy_df = xy_df [xy_df ["arm_name" ] != status_quo_arm_name ]
353+ if xy_df .empty :
354+ return None
355+
356+ if not trial_df [f"{ metric_name } _sem" ].isna ().all ():
357+ error_y = {
358+ "type" : "data" ,
359+ "array" : Z_SCORE_95_CI * xy_df [f"{ metric_name } _sem" ],
360+ "color" : ci_color ,
361+ }
362+ else :
363+ error_y = None
364+
365+ text = xy_df .apply (
366+ lambda row : get_arm_tooltip (row = row , metric_names = [metric_name ]), axis = 1
367+ )
368+
369+ scatter = go .Scatter (
370+ x = xy_df ["x_key_order" ],
371+ y = xy_df [f"{ metric_name } _mean" ],
372+ error_y = error_y ,
373+ mode = "markers" ,
374+ marker = {"color" : color },
375+ name = trace_name ,
376+ showlegend = showlegend ,
377+ hoverinfo = "text" ,
378+ text = text ,
379+ legendgroup = legendgroup ,
380+ )
381+ return scatter , xy_df ["x_key_order" ].to_list (), xy_df ["arm_name" ].to_list ()
382+
383+
321384def _prepare_figure (
322385 df : pd .DataFrame ,
323386 metric_name : str ,
@@ -354,76 +417,94 @@ def _prepare_figure(
354417 num_non_candidate_trials = 0
355418 candidate_trial_marker = None
356419
420+ # --- Trial loop ---
357421 for trial_index in trial_indices :
358- trial_df = df [df [ "trial_index" ] == trial_index ]
359- xy_df = trial_df [ ~ trial_df [ f" { metric_name } _mean " ].isna ()]
360- # Skip trials with no valid data points as they will not end up in the plot
361- if xy_df .empty :
422+ trial_df = df [
423+ ( df [ "trial_index" ] == trial_index ) & ( df [ "generator_run_key " ].isna ())
424+ ]
425+ if trial_df .empty :
362426 continue
363- if is_relative and status_quo_arm_name is not None :
364- # Exclude status quo arms from relativized plots, since arms are relative
365- # with respect to the status quo.
366- xy_df = xy_df [xy_df ["arm_name" ] != status_quo_arm_name ]
367-
368- arm_order = arm_order + xy_df ["x_key_order" ].to_list ()
369- arm_label = arm_label + xy_df ["arm_name" ].to_list ()
370- if not trial_df [f"{ metric_name } _sem" ].isna ().all ():
371- error_y = {
372- "type" : "data" ,
373- "array" : Z_SCORE_95_CI * xy_df [f"{ metric_name } _sem" ],
374- "color" : trial_index_to_color (
375- trial_df = trial_df ,
376- trials_list = trials_list ,
377- trial_index = trial_index ,
378- transparent = True ,
379- ),
380- }
381- else :
382- error_y = None
383-
384- marker = {
385- "color" : trial_index_to_color (
386- trial_df = trial_df ,
387- trials_list = trials_list ,
388- trial_index = trial_index ,
389- transparent = False ,
390- ),
391- }
392-
393- if trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name :
427+ color = trial_index_to_color (
428+ trial_df = trial_df ,
429+ trials_list = trials_list ,
430+ trial_index = trial_index ,
431+ transparent = False ,
432+ )
433+ ci_color = trial_index_to_color (
434+ trial_df = trial_df ,
435+ trials_list = trials_list ,
436+ trial_index = trial_index ,
437+ transparent = True ,
438+ )
439+ is_candidate = (
440+ not trial_df .empty
441+ and trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name
442+ )
443+ result = _build_scatter (
444+ trial_df = trial_df ,
445+ metric_name = metric_name ,
446+ is_relative = is_relative ,
447+ status_quo_arm_name = status_quo_arm_name ,
448+ color = color ,
449+ ci_color = ci_color ,
450+ trace_name = get_trial_trace_name (trial_index = trial_index ),
451+ showlegend = False , # Will be set after determining use_colorscale
452+ legendgroup = "candidate_trials" if is_candidate else None ,
453+ )
454+ if result is None :
455+ continue
456+ scatter , order , labels = result
457+ scatters .append (scatter )
458+ scatter_trial_indices .append (trial_index )
459+ arm_order += order
460+ arm_label += labels
461+ if is_candidate :
394462 num_candidate_trials += 1
395- candidate_trial_marker = marker
463+ candidate_trial_marker = { "color" : color }
396464 else :
397465 num_non_candidate_trials += 1
398466
399- text = xy_df .apply (
400- lambda row : get_arm_tooltip (row = row , metric_names = [metric_name ]), axis = 1
467+ # --- Generator run loop ---
468+ unique_gr_keys = df ["generator_run_key" ].dropna ().unique ().tolist ()
469+ generator_run_scatters : list [go .Scatter ] = []
470+ for gr_key in unique_gr_keys :
471+ gr_df = df [df ["generator_run_key" ] == gr_key ]
472+ color = generator_run_key_to_color (
473+ generator_run_key = gr_key ,
474+ all_generator_run_keys = unique_gr_keys ,
475+ transparent = False ,
401476 )
402-
403- scatters .append (
404- go .Scatter (
405- x = xy_df ["x_key_order" ],
406- y = xy_df [f"{ metric_name } _mean" ],
407- error_y = error_y ,
408- mode = "markers" ,
409- marker = marker ,
410- name = get_trial_trace_name (trial_index = trial_index ),
411- showlegend = False , # Will be set after determining use_colorscale
412- hoverinfo = "text" ,
413- text = text ,
414- legendgroup = "candidate_trials"
415- if trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name
416- else None ,
417- )
477+ ci_color = generator_run_key_to_color (
478+ generator_run_key = gr_key ,
479+ all_generator_run_keys = unique_gr_keys ,
480+ transparent = True ,
418481 )
419- scatter_trial_indices .append (trial_index )
482+ result = _build_scatter (
483+ trial_df = gr_df ,
484+ metric_name = metric_name ,
485+ is_relative = is_relative ,
486+ status_quo_arm_name = status_quo_arm_name ,
487+ color = color ,
488+ ci_color = ci_color ,
489+ trace_name = gr_key ,
490+ showlegend = True ,
491+ legendgroup = None ,
492+ )
493+ if result is None :
494+ continue
495+ scatter , order , labels = result
496+ generator_run_scatters .append (scatter )
497+ arm_order += order
498+ arm_label += labels
420499
421500 # Determine use_colorscale based on actual included trials
422501 use_colorscale = num_non_candidate_trials > 10
423502
424503 # Update markers and legend settings based on use_colorscale
425504 for scatter , trial_index in zip (scatters , scatter_trial_indices ):
426- trial_df = df [df ["trial_index" ] == trial_index ]
505+ trial_df = df [
506+ (df ["trial_index" ] == trial_index ) & (df ["generator_run_key" ].isna ())
507+ ]
427508
428509 if use_colorscale :
429510 # Add colorscale settings to marker
@@ -449,6 +530,9 @@ def _prepare_figure(
449530 trial_df ["trial_status" ].iloc [0 ] != TrialStatus .CANDIDATE .name
450531 )
451532
533+ # Append generator run scatters (not subject to colorscale)
534+ scatters .extend (generator_run_scatters )
535+
452536 # get the max length of x-ticker (arm name) to set the xaxis label and
453537 # legend position
454538 # This assumes the x-tickers are rotated 90 degrees (vertical) so legend
0 commit comments